discrete_wavlm_spk_rec_ecapatdn / custom_interface.py
poonehmousavi's picture
Update custom_interface.py
04455e3 verified
raw
history blame
9.35 kB
import torch
from speechbrain.inference.interfaces import Pretrained
class AttentionMLP(torch.nn.Module):
    """Scores input vectors with a small MLP and softmax-normalizes the scores.

    The softmax is taken over dimension 2 of the input. For the 4-D token
    embeddings this file's encoder feeds in (presumably laid out as
    (batch, time, num_codebooks, dim) -- TODO confirm), that is the codebook
    axis, so the output weights say how much each codebook contributes.
    NOTE(review): the axis is hard-coded; other input layouts need care.

    Arguments
    ---------
    input_dim : int
        Size of the last (feature) dimension of the input.
    hidden_dim : int
        Width of the hidden layer of the scoring MLP.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # The Sequential must stay under the name ``layers`` with this exact
        # order: renaming or reordering would change the state_dict keys
        # (layers.0.*, layers.2.*) and break loading of saved weights.
        project = torch.nn.Linear(input_dim, hidden_dim)
        to_score = torch.nn.Linear(hidden_dim, 1, bias=False)
        self.layers = torch.nn.Sequential(project, torch.nn.ReLU(), to_score)

    def forward(self, x):
        """Returns attention weights for ``x``.

        Arguments
        ---------
        x : torch.Tensor
            Tensor with at least three dimensions whose last dimension is
            ``input_dim``.

        Returns
        -------
        torch.Tensor
            Same shape as ``x`` except the last dimension is 1; the values
            are non-negative and sum to 1 along dimension 2.
        """
        scores = self.layers(x)
        return scores.softmax(dim=2)
class Discrete_EmbeddingLayer(torch.nn.Module):
    """Embedding layer for discrete multi-codebook tokens.

    All codebooks share one ``torch.nn.Embedding`` table with
    ``num_codebooks * vocab_size`` rows. Token ``t`` of codebook ``k`` is
    looked up at row ``k * vocab_size + t``, so equal token IDs coming from
    different codebooks map to different embedding vectors.

    Arguments
    ---------
    num_codebooks : int
        Number of codebooks of the tokenizer.
    vocab_size : int
        Size of each codebook's vocabulary.
    emb_dim : int
        The size of each embedding vector.
    pad_index : int (default: 0)
        Accepted for interface compatibility but currently NOT used: it is
        not forwarded to ``torch.nn.Embedding`` as ``padding_idx``. Note that
        row 0 of the shared table is a real token (codebook 0, token 0).
    init : bool (default: False)
        Stored as ``self.init``; nothing in this class reads it. Presumably
        the surrounding recipe calls ``init_embedding`` when it is True.
    freeze : bool (default: False)
        If True, the embedding table is frozen (its weight does not require
        gradients and the lookup runs with autograd disabled). If False, it
        is trained alongside the rest of the pipeline.

    Example
    -------
    >>> from speechbrain.lobes.models.huggingface_transformers.encodec import Encodec
    >>> model_hub = "facebook/encodec_24khz"
    >>> save_path = "savedir"
    >>> model = Encodec(model_hub, save_path)
    >>> audio = torch.randn(4, 1000)
    >>> length = torch.tensor([1.0, .5, .75, 1.0])
    >>> tokens, emb = model.encode(audio, length)
    >>> print(tokens.shape)
    torch.Size([4, 4, 2])
    >>> emb = Discrete_EmbeddingLayer(2, 1024, 1024)
    >>> in_emb = emb(tokens)
    >>> print(in_emb.shape)
    torch.Size([4, 4, 2, 1024])
    """

    def __init__(
        self,
        num_codebooks,
        vocab_size,
        emb_dim,
        pad_index=0,
        init=False,
        freeze=False,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_codebooks = num_codebooks
        self.freeze = freeze
        # Attribute name ``embedding`` is part of the state_dict keys; keep it.
        self.embedding = torch.nn.Embedding(
            num_codebooks * vocab_size, emb_dim
        ).requires_grad_(not self.freeze)
        self.init = init

    def init_embedding(self, weights):
        """Replaces the embedding table with ``weights``.

        Arguments
        ---------
        weights : torch.Tensor
            New embedding matrix; expected to be shaped
            (num_codebooks * vocab_size, emb_dim) so that ``forward``'s
            indexing stays valid.
        """
        with torch.no_grad():
            # Respect ``freeze``: a bare ``Parameter(weights)`` defaults to
            # requires_grad=True and would silently un-freeze the layer.
            self.embedding.weight = torch.nn.Parameter(
                weights, requires_grad=not self.freeze
            )

    def forward(self, in_tokens):
        """Computes the embeddings of discrete tokens.

        Arguments
        ---------
        in_tokens : torch.Tensor
            Integer token IDs of shape (batch, time, num_codebooks). The
            tensor is not modified.

        Returns
        -------
        in_embs : torch.Tensor
            Embeddings of shape (batch, time, num_codebooks, emb_dim).
        """
        with torch.set_grad_enabled(not self.freeze):
            # Shift codebook k's IDs by k * vocab_size so each codebook
            # indexes its own slice of the shared table. Done out of place:
            # an in-place ``+=`` would corrupt the caller's tensor and make
            # a second call on the same tokens shift them twice.
            offsets = torch.arange(
                0,
                self.num_codebooks * self.vocab_size,
                self.vocab_size,
                device=in_tokens.device,
            )
            in_embs = self.embedding(in_tokens + offsets)
        return in_embs
class CustomEncoderClassifier(Pretrained):
    """Speaker-embedding and verification interface built on discrete tokens.

    Waveforms are turned into discrete tokens by a codec. The per-codebook
    token embeddings are merged with attention weights and fed to an
    embedding model that yields one vector per utterance. Two utterances are
    compared with cosine similarity.

    The hyperparameter file must provide:

    * ``hparams.codec``: called as ``codec(wavs, wav_lens, **tokenizer_config)``
      and returning a 3-tuple whose first item is the token tensor;
    * ``hparams.tokenizer_config``: keyword arguments for that codec call;
    * ``mods.discrete_embedding_layer``, ``mods.attention_mlp`` and
      ``mods.embedding_model``: the modules applied in ``encode_batch``.

    Use ``encode_batch`` to extract embeddings, and ``verify_batch`` /
    ``verify_files`` for speaker verification.

    Example
    -------
    >>> classifier = CustomEncoderClassifier.from_hparams(
    ...     source="<model-source>", savedir="<tmpdir>"
    ... )  # doctest: +SKIP
    >>> signal = torch.randn(1, 16000)
    >>> embeddings = classifier.encode_batch(signal)  # doctest: +SKIP
    >>> score, same = classifier.verify_files("a.wav", "b.wav")  # doctest: +SKIP
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Compares embeddings along their last (feature) dimension.
        self.similarity = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        """Encodes a batch of waveforms into one embedding per utterance.

        The waveforms should already be in the model's expected format; in
        most cases ``<this>.normalizer(signal, sample_rate)`` converts them.

        Arguments
        ---------
        wavs : torch.Tensor
            Waveforms of shape [batch, time] (or [batch, time, channels],
            depending on the codec). A 1-D tensor is treated as a batch of
            one. The original documentation requires a 16 kHz sample rate.
        wav_lens : torch.Tensor, optional
            Relative lengths of shape [batch] (longest = 1.0), used to
            ignore padding. Defaults to all ones.
        normalize : bool
            Currently ignored: no embedding normalization is applied
            whatever its value.

        Returns
        -------
        torch.Tensor
            The embedding model's output with dimension 1 squeezed
            (presumably [batch, emb_dim] -- confirm against embedding_model).
        """
        if wavs.dim() == 1:
            wavs = wavs.unsqueeze(0)
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        wavs = wavs.to(self.device).float()
        wav_lens = wav_lens.to(self.device)

        # Only tokenization runs with autograd disabled; the codec is moved
        # to the current device and switched to eval mode on every call.
        codec = self.hparams.codec
        with torch.no_grad():
            codec.to(self.device).eval()
            tokens, _, _ = codec(
                wavs, wav_lens, **self.hparams.tokenizer_config
            )

        token_embs = self.mods.discrete_embedding_layer(tokens)
        weights = self.mods.attention_mlp(token_embs)
        # Attention-weighted sum over dimension 2, then drop the singleton.
        pooled = (weights.transpose(2, -1) @ token_embs).squeeze(-2)
        utterance_embs = self.mods.embedding_model(pooled, wav_lens)
        return utterance_embs.squeeze(1)

    def verify_batch(
        self, wavs1, wavs2, wav1_lens=None, wav2_lens=None, threshold=0.25
    ):
        """Speaker verification by cosine similarity of embeddings.

        Arguments
        ---------
        wavs1 : torch.Tensor
            First batch of waveforms (batch, time).
        wavs2 : torch.Tensor
            Second batch of waveforms (batch, time), paired element-wise
            with ``wavs1``.
        wav1_lens : torch.Tensor, optional
            Relative lengths for ``wavs1`` (e.g. [0.8, 0.6, 1.0]).
        wav2_lens : torch.Tensor, optional
            Relative lengths for ``wavs2``.
        threshold : float
            Pairs whose cosine similarity exceeds this value are declared
            the same speaker.

        Returns
        -------
        score : torch.Tensor
            Cosine similarity of each pair of embeddings.
        prediction : torch.Tensor
            Boolean tensor, True where ``score > threshold`` (same speaker).
        """
        first = self.encode_batch(wavs1, wav1_lens, normalize=False)
        second = self.encode_batch(wavs2, wav2_lens, normalize=False)
        score = self.similarity(first, second)
        decision = score > threshold
        return score, decision

    def verify_files(self, path_x, path_y, **kwargs):
        """Speaker verification between two audio files.

        Arguments
        ---------
        path_x : str
            Path to the first audio file.
        path_y : str
            Path to the second audio file.
        **kwargs : dict
            Forwarded to ``load_audio``.

        Returns
        -------
        score : torch.Tensor
            Cosine similarity of the two embeddings (single element).
        prediction : torch.Tensor
            True if the files are judged to come from the same speaker,
            using ``verify_batch``'s default threshold.
        """
        # Wrap each waveform in a batch of one.
        batch_x = self.load_audio(path_x, **kwargs).unsqueeze(0)
        batch_y = self.load_audio(path_y, **kwargs).unsqueeze(0)
        score, decision = self.verify_batch(batch_x, batch_y)
        # Unwrap the single-pair batch.
        return score[0], decision[0]