# mikeross's picture
# Duplicate from nateraw/deepafx-st
# c983126
import torch
from deepafx_st.models.mobilenetv2 import MobileNetV2
from deepafx_st.models.efficient_net import EfficientNet
class SpectralEncoder(torch.nn.Module):
    """Map a waveform to a unit-L2-norm embedding via its STFT magnitude.

    The waveform is turned into a power-law compressed, standardized STFT
    magnitude "image", passed through a 2-D convolutional backbone
    (``MobileNetV2`` or ``EfficientNet``), reduced to one vector per batch
    item, and L2-normalized.
    """

    # STFT analysis settings used by ``forward`` (FFT size and hop, in samples).
    _N_FFT = 4096
    _HOP_LENGTH = 2048
    # Power-law compression applied to the magnitude; the small offset keeps
    # the power function from being evaluated at exactly zero.
    _MAG_OFFSET = 1e-8
    _MAG_EXPONENT = 0.3
    # Fixed statistics used to standardize the compressed magnitudes.
    _SPEC_MEAN = 0.322970
    _SPEC_STD = 0.278452

    def __init__(
        self,
        num_params,
        sample_rate,
        encoder_model="mobilenet_v2",
        embed_dim=1028,
        width_mult=1,
        min_level_db=-80,
    ):
        """Encoder operating on spectrograms.

        Args:
            num_params (int): Number of processor parameters to generate.
                Stored on the instance; not used by this class itself.
            sample_rate (float): Audio sample rate. Stored on the instance;
                not used by this class itself.
            encoder_model (str, optional): Encoder model architecture, either
                "mobilenet_v2" or "efficient_net". Default: "mobilenet_v2"
            embed_dim (int, optional): Dimensionality of the encoder
                representations. Default: 1028
            width_mult (int, optional): Encoder size, forwarded to
                ``MobileNetV2`` only. Default: 1
            min_level_db (float, optional): Minimal dB value for the
                spectrogram. Stored on the instance; not used by this class
                itself. Default: -80

        Raises:
            ValueError: If ``encoder_model`` is not a supported architecture.
        """
        super().__init__()
        self.num_params = num_params
        self.sample_rate = sample_rate
        self.encoder_model = encoder_model
        self.embed_dim = embed_dim
        self.width_mult = width_mult
        self.min_level_db = min_level_db

        # Build the backbone from the project's own model classes.
        if encoder_model == "mobilenet_v2":
            self.encoder = MobileNetV2(embed_dim=embed_dim, width_mult=width_mult)
        elif encoder_model == "efficient_net":
            self.encoder = EfficientNet.from_name(
                "efficientnet-b2",
                in_channels=1,
                image_size=(128, 65),
                include_top=False,
            )
            # 1x1 convolution projecting the backbone features to embed_dim.
            # NOTE(review): in_channels=1408 presumably matches the channel
            # count of the efficientnet-b2 feature map -- confirm.
            self.embedding_projection = torch.nn.Conv2d(
                in_channels=1408,
                out_channels=embed_dim,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=(0, 0),
                bias=True,
            )
        else:
            raise ValueError(f"Invalid encoder_model: {encoder_model}.")

        # Analysis window for the STFT. It is registered as a Parameter, so it
        # is part of the state dict under the key "window".
        # NOTE(review): as a Parameter it is also returned by parameters() and
        # receives gradients; a buffer may have been intended -- kept as-is so
        # existing checkpoints and optimizer setups are unaffected.
        self.window = torch.nn.Parameter(torch.hann_window(self._N_FFT))

    def forward(self, x):
        """Compute the normalized embedding of a batch of waveforms.

        Args:
            x (Tensor): Input waveform of shape [batch x channels x samples].
                All channels of an item are flattened into a single sequence
                before the STFT is taken.

        Returns:
            e (Tensor): Latent embedding produced by Encoder, scaled to unit
                L2 norm along the last dimension. [batch x embed_dim]

        Raises:
            ValueError: If ``self.encoder_model`` is not a supported
                architecture (e.g. it was modified after construction).
        """
        # Unpacking three values keeps the requirement of a 3-D input.
        batch_size, _, _ = x.size()

        # reshape (rather than view) also accepts non-contiguous inputs.
        spec = torch.stft(
            x.reshape(batch_size, -1),
            self._N_FFT,
            self._HOP_LENGTH,
            window=self.window,
            return_complex=True,
        )

        # Power-law compressed magnitude, then standardization with the fixed
        # statistics. Done out-of-place so no intermediate tensor is mutated.
        magnitude = torch.pow(spec.abs() + self._MAG_OFFSET, self._MAG_EXPONENT)
        features = (magnitude - self._SPEC_MEAN) / self._SPEC_STD

        # Add a channel axis and swap the last two axes:
        # [batch, freq, frames] -> [batch, 1, frames, freq]
        features = features.unsqueeze(1).permute(0, 1, 3, 2)

        if self.encoder_model == "mobilenet_v2":
            # repeat channels by 3 to fit vision model
            features = features.repeat(1, 3, 1, 1)
            e = self.encoder(features)
            # global average pooling over the remaining spatial axes
            e = torch.nn.functional.adaptive_avg_pool2d(e, 1).reshape(e.shape[0], -1)
        elif self.encoder_model == "efficient_net":
            # Per the original note: EfficientNet downsamples by 32 on both
            # axes internally, then average pools the rest.
            e = self.encoder(features)
            # 1x1 conv to project down or up to the requested embedding size
            e = self.embedding_projection(e)
            e = torch.squeeze(e, dim=3)
            e = torch.squeeze(e, dim=2)
        else:
            raise ValueError(f"Invalid encoder_model: {self.encoder_model}.")

        # L2-normalize. normalize() clamps the denominator, so an all-zero
        # embedding yields zeros instead of NaN from a 0 / 0 division.
        return torch.nn.functional.normalize(e, p=2, dim=-1)