# discrete_wavlm_spk_rec_ecapatdn / custom_interface.py
# Uploaded by poonehmousavi ("Upload 10 files", commit 89a1ae3, verified); 6.16 kB.
import torch
from speechbrain.inference.interfaces import Pretrained
class CustomEncoderClassifier(Pretrained):
    """Utterance-level embedding extractor and speaker verifier.

    The interface tokenizes the waveform with a discrete codec and turns the
    tokens into a single embedding per utterance. The yaml file is expected to
    define:

    * ``codec`` and ``tokenizer_config`` (read from ``self.hparams``);
    * ``discrete_embedding_layer``, ``attention_mlp`` and ``embedding_model``
      (read from ``self.mods``).

    The class can be used to extract embeddings (``encode_batch()``) or to
    compare two utterances with a cosine-similarity score
    (``verify_batch()`` / ``verify_files()``).

    Example
    -------
    Illustrative usage (not run as a doctest; ``source`` is a placeholder for
    the directory or hub repository holding the yaml and checkpoints)::

        from speechbrain.inference.interfaces import foreign_class

        verifier = foreign_class(
            source="<model-dir-or-hub-repo>",
            pymodule_file="custom_interface.py",
            classname="CustomEncoderClassifier",
        )
        embeddings = verifier.encode_batch(signal)
        score, decision = verifier.verify_files("a.wav", "b.wav")
    """

    def __init__(self, *args, **kwargs):
        """Build the interface and the cosine-similarity scorer.

        Arguments
        ---------
        *args : tuple
            Positional arguments forwarded unchanged to ``Pretrained``.
        **kwargs : dict
            Keyword arguments forwarded unchanged to ``Pretrained``.
        """
        super().__init__(*args, **kwargs)
        # Similarity is computed over the last (embedding) dimension.
        self.similarity = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        """Encodes the input audio into a single vector embedding.

        The waveforms should already be in the model's desired format.
        You can call:
        ``normalized = <this>.normalizer(signal, sample_rate)``
        to get a correctly converted signal in most cases.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model. A 1-D tensor is treated as a single
            waveform. Make sure the sample rate is fs=16000 Hz.
        wav_lens : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding. Defaults to all ones.
        normalize : bool
            Accepted for signature compatibility only. This implementation
            never reads the flag, so no normalization is applied regardless
            of its value.

        Returns
        -------
        torch.Tensor
            The output of ``embedding_model`` with dimension 1 squeezed out
            (when that dimension has size 1).
        """
        # Manage single waveforms in input
        if wavs.ndim == 1:
            wavs = wavs.unsqueeze(0)

        # Assign full length if wav_lens is not assigned
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        # Storing waveform in the specified device
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        wavs = wavs.float()

        # NOTE(review): only the tokenizer call is kept under no_grad here;
        # confirm this is the intended scope of the context manager.
        with torch.no_grad():
            tokens, _, _ = self.hparams.codec(
                wavs, wav_lens, **self.hparams.tokenizer_config
            )

        embeddings = self.mods.discrete_embedding_layer(tokens)
        # Attention weights are used to combine the token embeddings into one
        # feature stream (presumably across codebooks/layers -- TODO confirm).
        att_w = self.mods.attention_mlp(embeddings)
        feats = torch.matmul(att_w.transpose(2, -1), embeddings).squeeze(-2)
        embeddings = self.mods.embedding_model(feats, wav_lens)
        return embeddings.squeeze(1)

    def verify_batch(
        self, wavs1, wavs2, wav1_lens=None, wav2_lens=None, threshold=0.25
    ):
        """Performs speaker verification with cosine similarity.

        It returns the score and the decision (False/0 different speakers,
        True/1 same speaker).

        Arguments
        ---------
        wavs1 : torch.Tensor
            torch.Tensor containing the speech waveform1 (batch, time).
            Make sure the sample rate is fs=16000 Hz.
        wavs2 : torch.Tensor
            torch.Tensor containing the speech waveform2 (batch, time).
            Make sure the sample rate is fs=16000 Hz.
        wav1_lens : torch.Tensor
            torch.Tensor containing the relative length for each sentence
            in the batch (e.g., [0.8 0.6 1.0])
        wav2_lens : torch.Tensor
            torch.Tensor containing the relative length for each sentence
            in the batch (e.g., [0.8 0.6 1.0])
        threshold : float
            Threshold applied to the cosine similarity: scores strictly
            above it are decided as "same speaker".

        Returns
        -------
        score : torch.Tensor
            Cosine similarity between the two embeddings (higher means more
            similar).
        prediction : torch.Tensor
            Boolean tensor, ``score > threshold``: True if the two signals
            in input are judged to be from the same speaker.
        """
        emb1 = self.encode_batch(wavs1, wav1_lens, normalize=False)
        emb2 = self.encode_batch(wavs2, wav2_lens, normalize=False)
        score = self.similarity(emb1, emb2)
        return score, score > threshold

    def verify_files(self, path_x, path_y, **kwargs):
        """Speaker verification with cosine similarity on two audio files.

        Returns the score and the decision (False/0 different speakers,
        True/1 same speaker). The decision uses the default threshold of
        ``verify_batch``.

        Arguments
        ---------
        path_x : str
            Path to file x
        path_y : str
            Path to file y
        **kwargs : dict
            Arguments to ``load_audio``

        Returns
        -------
        score : torch.Tensor
            Cosine similarity for the pair (first element of the batch
            score).
        prediction : torch.Tensor
            True if the two files are judged to be from the same speaker,
            False otherwise.
        """
        waveform_x = self.load_audio(path_x, **kwargs)
        waveform_y = self.load_audio(path_y, **kwargs)
        # Fake batches:
        batch_x = waveform_x.unsqueeze(0)
        batch_y = waveform_y.unsqueeze(0)
        # Verify:
        score, decision = self.verify_batch(batch_x, batch_y)
        # Squeeze:
        return score[0], decision[0]