# discrete_wavlm_spk_rec_ecapatdn/custom_interface.py
import torch
from speechbrain.inference.interfaces import Pretrained


class CustomEncoderClassifier(Pretrained):
"""A ready-to-use class for utterance-level classification (e.g, speaker-id,
language-id, emotion recognition, keyword spotting, etc).
The class assumes that an self-supervised encoder like wav2vec2/hubert and a classifier model
are defined in the yaml file. If you want to
convert the predicted index into a corresponding text label, please
provide the path of the label_encoder in a variable called 'lab_encoder_file'
within the yaml.
The class can be used either to run only the encoder (encode_batch()) to
extract embeddings or to run a classification step (classify_batch()).
```
Example
-------
>>> import torchaudio
>>> from speechbrain.pretrained import EncoderClassifier
>>> # Model is downloaded from the speechbrain HuggingFace repo
>>> tmpdir = getfixture("tmpdir")
>>> classifier = EncoderClassifier.from_hparams(
... source="speechbrain/spkrec-ecapa-voxceleb",
... savedir=tmpdir,
... )
>>> # Compute embeddings
>>> signal, fs = torchaudio.load("samples/audio_samples/example1.wav")
>>> embeddings = classifier.encode_batch(signal)
>>> # Classification
    >>> prediction = classifier.classify_batch(signal)
"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.similarity = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
"""Encodes the input audio into a single vector embedding.
The waveforms should already be in the model's desired format.
You can call:
``normalized = <this>.normalizer(signal, sample_rate)``
to get a correctly converted signal in most cases.
Arguments
---------
wavs : torch.tensor
Batch of waveforms [batch, time, channels] or [batch, time]
depending on the model. Make sure the sample rate is fs=16000 Hz.
wav_lens : torch.tensor
Lengths of the waveforms relative to the longest one in the
batch, tensor of shape [batch]. The longest one should have
relative length 1.0 and others len(waveform) / max_length.
Used for ignoring padding.
normalize : bool
If True, it normalizes the embeddings with the statistics
contained in mean_var_norm_emb.
Returns
-------
torch.tensor
The encoded batch
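
        Example
        -------
        Illustrative sketch: ``classifier`` is assumed to be an already-loaded
        ``CustomEncoderClassifier`` and ``signal`` a 16 kHz waveform tensor.

        >>> embeddings = classifier.encode_batch(signal)  # doctest: +SKIP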
"""
# Manage single waveforms in input
if len(wavs.shape) == 1:
wavs = wavs.unsqueeze(0)
# Assign full length if wav_lens is not assigned
if wav_lens is None:
wav_lens = torch.ones(wavs.shape[0], device=self.device)
# Storing waveform in the specified device
wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
wavs = wavs.float()
        with torch.no_grad():
            # Tokenize the waveform with the (frozen) discrete codec
            self.hparams.codec.to(self.device).eval()
            tokens, _, _ = self.hparams.codec(
                wavs, wav_lens, **self.hparams.tokenizer_config
            )
            # Embed the tokens and collapse the codebook dimension
            # with learned attention weights
            embeddings = self.mods.discrete_embedding_layer(tokens)
            att_w = self.mods.attention_mlp(embeddings)
            feats = torch.matmul(att_w.transpose(2, -1), embeddings).squeeze(-2)
            embeddings = self.mods.embedding_model(feats, wav_lens)
            # Apply the documented ``normalize`` option when the hparams
            # provide mean_var_norm_emb (guarded, since they may not)
            if normalize and hasattr(self.hparams, "mean_var_norm_emb"):
                embeddings = self.hparams.mean_var_norm_emb(
                    embeddings,
                    torch.ones(embeddings.shape[0], device=self.device),
                )
        return embeddings.squeeze(1)

    def verify_batch(
self, wavs1, wavs2, wav1_lens=None, wav2_lens=None, threshold=0.25
):
"""Performs speaker verification with cosine distance.
It returns the score and the decision (0 different speakers,
1 same speakers).
Arguments
---------
wavs1 : Torch.Tensor
torch.Tensor containing the speech waveform1 (batch, time).
Make sure the sample rate is fs=16000 Hz.
wavs2 : Torch.Tensor
torch.Tensor containing the speech waveform2 (batch, time).
Make sure the sample rate is fs=16000 Hz.
wav1_lens : Torch.Tensor
torch.Tensor containing the relative length for each sentence
in the length (e.g., [0.8 0.6 1.0])
wav2_lens : Torch.Tensor
torch.Tensor containing the relative length for each sentence
in the length (e.g., [0.8 0.6 1.0])
threshold : Float
Threshold applied to the cosine distance to decide if the
speaker is different (0) or the same (1).
Returns
-------
score
The score associated to the binary verification output
(cosine distance).
prediction
The prediction is 1 if the two signals in input are from the same
speaker and 0 otherwise.
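
        Example
        -------
        Illustrative sketch: ``classifier`` is assumed to be an already-loaded
        ``CustomEncoderClassifier``; ``wavs1`` and ``wavs2`` are [batch, time]
        waveforms sampled at 16 kHz.

        >>> score, decision = classifier.verify_batch(wavs1, wavs2)  # doctest: +SKIP
        >>> # decision[i] is True when wavs1[i] and wavs2[i] share a speaker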
"""
emb1 = self.encode_batch(wavs1, wav1_lens, normalize=False)
emb2 = self.encode_batch(wavs2, wav2_lens, normalize=False)
score = self.similarity(emb1, emb2)
return score, score > threshold

    def verify_files(self, path_x, path_y, **kwargs):
        """Performs speaker verification with cosine similarity on two audio files.

        Returns the score and the decision (0 = different speakers,
        1 = same speaker).

        Arguments
        ---------
        path_x : str
            Path to file x.
        path_y : str
            Path to file y.
        **kwargs : dict
            Arguments to ``load_audio``.

        Returns
        -------
        score
            The score associated with the binary verification output
            (cosine similarity).
        prediction
            The prediction is 1 if the two input signals are from the same
speaker and 0 otherwise.
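
        Example
        -------
        Illustrative sketch; the file paths below are hypothetical.

        >>> score, decision = classifier.verify_files(
        ...     "spk1_utt1.wav", "spk1_utt2.wav"
        ... )  # doctest: +SKIP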
"""
waveform_x = self.load_audio(path_x, **kwargs)
waveform_y = self.load_audio(path_y, **kwargs)
# Fake batches:
batch_x = waveform_x.unsqueeze(0)
batch_y = waveform_y.unsqueeze(0)
# Verify:
score, decision = self.verify_batch(batch_x, batch_y)
# Squeeze:
return score[0], decision[0]
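

# Minimal usage sketch. Assumptions: the ``source`` repo id below is a
# placeholder for wherever this model is hosted, and the wav paths are
# hypothetical. ``foreign_class`` is SpeechBrain's standard entry point for
# loading a custom interface such as this one.
if __name__ == "__main__":
    from speechbrain.inference.interfaces import foreign_class

    classifier = foreign_class(
        source="poonehmousavi/discrete_wavlm_spk_rec_ecapatdn",  # assumed repo id
        pymodule_file="custom_interface.py",
        classname="CustomEncoderClassifier",
    )
    # Verify whether two recordings come from the same speaker
    score, decision = classifier.verify_files(
        "speaker1_utt1.wav", "speaker1_utt2.wav"
    )
    print(f"cosine similarity: {score.item():.3f}, same speaker: {bool(decision)}")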