import torch

from speechbrain.inference.interfaces import Pretrained


class CustomEncoderClassifier(Pretrained):
    """Utterance-level embedding extractor and verifier based on discrete tokens.

    The class expects the following objects to be defined in the
    hyperparameter (yaml) file:

    * ``codec`` and ``tokenizer_config`` (read from ``self.hparams``): the
      codec is called on the waveforms and returns discrete tokens.
    * ``discrete_embedding_layer``, ``attention_mlp`` and ``embedding_model``
      (read from ``self.mods``): they turn the tokens into one embedding per
      utterance.

    The class can be used to extract embeddings (``encode_batch()``) or to
    compare two utterances with cosine similarity (``verify_batch()`` and
    ``verify_files()``).

    Example
    -------
    >>> import torchaudio  # doctest: +SKIP
    >>> verifier = CustomEncoderClassifier.from_hparams(  # doctest: +SKIP
    ...     source="path/or/hub-id/of/the/model",
    ...     savedir="pretrained_model",
    ... )
    >>> # Compute embeddings
    >>> signal, fs = torchaudio.load("example1.wav")  # doctest: +SKIP
    >>> embeddings = verifier.encode_batch(signal)  # doctest: +SKIP
    >>> # Verification
    >>> score, decision = verifier.verify_batch(signal, signal)  # doctest: +SKIP
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Scoring function used by verify_batch(); it compares embeddings
        # along their last dimension.
        self.similarity = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        """Encodes the input audio into a single vector embedding.

        The waveforms should already be in the model's desired format.
        You can call:
        ``normalized = <this>.normalizer(signal, sample_rate)``
        to get a correctly converted signal in most cases.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model. A 1-dimensional tensor is treated as a
            single waveform and turned into a batch of size one.
            Make sure the sample rate is fs=16000 Hz.
        wav_lens : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding. If None, every waveform is assumed
            to have relative length 1.0.
        normalize : bool
            Accepted for signature compatibility only: this implementation
            does not read the flag and never normalizes the embeddings.

        Returns
        -------
        torch.Tensor
            The encoded batch (the output of the embedding model with
            dimension 1 squeezed out when it has size one).
        """
        # Manage single waveforms in input.
        if wavs.dim() == 1:
            wavs = wavs.unsqueeze(0)

        # Assign full length when no relative lengths are given.
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        # Move the inputs to the model device and make the audio float.
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        wavs = wavs.float()

        # Token extraction runs without gradient tracking, with the codec in
        # eval mode on the model device.
        # NOTE(review): the extent of this no_grad scope was reconstructed
        # from a source without indentation; only the token extraction is
        # wrapped here -- confirm against the reference implementation.
        with torch.no_grad():
            self.hparams.codec.to(self.device).eval()
            tokens, _, _ = self.hparams.codec(
                wavs, wav_lens, **self.hparams.tokenizer_config
            )

        embeddings = self.mods.discrete_embedding_layer(tokens)

        # Attention-weighted sum of the token embeddings: the weights are
        # transposed so that the matmul collapses the second-to-last axis.
        # Presumably this pools over the codebook/layer axis -- TODO confirm.
        att_w = self.mods.attention_mlp(embeddings)
        feats = torch.matmul(att_w.transpose(2, -1), embeddings).squeeze(-2)

        embeddings = self.mods.embedding_model(feats, wav_lens)
        return embeddings.squeeze(1)

    def verify_batch(
        self, wavs1, wavs2, wav1_lens=None, wav2_lens=None, threshold=0.25
    ):
        """Performs speaker verification with cosine similarity.

        It returns the score and the decision (True when the score is above
        the threshold, i.e. same speaker; False otherwise).

        Arguments
        ---------
        wavs1 : torch.Tensor
            torch.Tensor containing the speech waveform1 (batch, time).
            Make sure the sample rate is fs=16000 Hz.
        wavs2 : torch.Tensor
            torch.Tensor containing the speech waveform2 (batch, time).
            Make sure the sample rate is fs=16000 Hz.
        wav1_lens : torch.Tensor
            torch.Tensor containing the relative length for each sentence
            in the batch (e.g., [0.8 0.6 1.0]).
        wav2_lens : torch.Tensor
            torch.Tensor containing the relative length for each sentence
            in the batch (e.g., [0.8 0.6 1.0]).
        threshold : float
            Threshold applied to the cosine similarity: a score strictly
            greater than the threshold means "same speaker".

        Returns
        -------
        score
            The cosine similarity between the two embeddings; this is the
            score behind the binary verification output.
        prediction
            Boolean tensor, True if the two signals in input are judged to
            be from the same speaker and False otherwise.
        """
        emb1 = self.encode_batch(wavs1, wav1_lens, normalize=False)
        emb2 = self.encode_batch(wavs2, wav2_lens, normalize=False)
        score = self.similarity(emb1, emb2)
        return score, score > threshold

    def verify_files(self, path_x, path_y, **kwargs):
        """Speaker verification with cosine similarity on two audio files.

        Returns the score and the decision (True when the score is above
        the default threshold of ``verify_batch``, False otherwise).

        Arguments
        ---------
        path_x : str
            Path to file x
        path_y : str
            Path to file y
        **kwargs : dict
            Arguments to ``load_audio``

        Returns
        -------
        score
            The cosine similarity between the embeddings of the two files.
        prediction
            True if the two signals in input are judged to be from the same
            speaker and False otherwise.
        """
        waveform_x = self.load_audio(path_x, **kwargs)
        waveform_y = self.load_audio(path_y, **kwargs)

        # Fake batches of size one.
        batch_x = waveform_x.unsqueeze(0)
        batch_y = waveform_y.unsqueeze(0)

        score, decision = self.verify_batch(batch_x, batch_y)

        # Squeeze the batch dimension back out.
        return score[0], decision[0]