import torch

from speechbrain.inference.interfaces import Pretrained


class CustomEncoderClassifier(Pretrained):
    """A ready-to-use class for utterance-level classification (e.g.,
    speaker-id, language-id, emotion recognition, keyword spotting, etc.).

    The class assumes that a self-supervised encoder like wav2vec2/hubert
    and a classifier model are defined in the yaml file. If you want to
    convert the predicted index into a corresponding text label, please
    provide the path of the label_encoder in a variable called
    'lab_encoder_file' within the yaml.

    The class can be used either to run only the encoder (encode_batch())
    to extract embeddings or to run a classification step (classify_batch()).

    Example
    -------
    >>> import torchaudio
    >>> # Model is downloaded from the speechbrain HuggingFace repo
    >>> tmpdir = getfixture("tmpdir")
    >>> classifier = CustomEncoderClassifier.from_hparams(
    ...     source="speechbrain/spkrec-ecapa-voxceleb",
    ...     savedir=tmpdir,
    ... )
    >>> # Compute embeddings
    >>> signal, fs = torchaudio.load("samples/audio_samples/example1.wav")
    >>> embeddings = classifier.encode_batch(signal)
    >>> # Classification
    >>> prediction = classifier.classify_batch(signal)
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.similarity = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        """Encodes the input audio into a single vector embedding.

        The waveforms should already be in the model's desired format. You
        can call ``normalized = <this>.normalizer(signal, sample_rate)`` to
        get a correctly converted signal in most cases.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model. Make sure the sample rate is
            fs=16000 Hz.
        wav_lens : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and the others len(waveform) / max_length.
            Used for ignoring padding.
        normalize : bool
            If True, normalizes the embeddings with the statistics
            contained in mean_var_norm_emb.

        Returns
        -------
        torch.Tensor
            The encoded batch
        """
        # Manage single waveforms in input
        if len(wavs.shape) == 1:
            wavs = wavs.unsqueeze(0)

        # Assign full length if wav_lens is not assigned
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        # Store the waveforms on the specified device
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        wavs = wavs.float()

        with torch.no_grad():
            # Tokenize the waveforms with the (frozen) codec
            self.hparams.codec.to(self.device).eval()
            tokens, _, _ = self.hparams.codec(
                wavs, wav_lens, **self.hparams.tokenizer_config
            )
            # Embed the discrete tokens, pool them with attention weights,
            # and run the embedding model on the pooled features
            embeddings = self.mods.discrete_embedding_layer(tokens)
            att_w = self.mods.attention_mlp(embeddings)
            feats = torch.matmul(
                att_w.transpose(2, -1), embeddings
            ).squeeze(-2)
            embeddings = self.mods.embedding_model(feats, wav_lens)

        if normalize:
            # Apply the statistics stored in mean_var_norm_emb, as the
            # docstring promises (requires mean_var_norm_emb in the yaml)
            embeddings = self.hparams.mean_var_norm_emb(
                embeddings,
                torch.ones(embeddings.shape[0], device=self.device),
            )

        return embeddings.squeeze(1)

    def verify_batch(
        self, wavs1, wavs2, wav1_lens=None, wav2_lens=None, threshold=0.25
    ):
        """Performs speaker verification with cosine distance.

        It returns the score and the decision (0 different speakers,
        1 same speakers).

        Arguments
        ---------
        wavs1 : torch.Tensor
            Tensor containing the speech waveform1 (batch, time).
            Make sure the sample rate is fs=16000 Hz.
        wavs2 : torch.Tensor
            Tensor containing the speech waveform2 (batch, time).
            Make sure the sample rate is fs=16000 Hz.
        wav1_lens : torch.Tensor
            Tensor containing the relative length for each sentence
            in the batch (e.g., [0.8 0.6 1.0]).
        wav2_lens : torch.Tensor
            Tensor containing the relative length for each sentence
            in the batch (e.g., [0.8 0.6 1.0]).
        threshold : float
            Threshold applied to the cosine distance to decide if the
            speakers are different (0) or the same (1).

        Returns
        -------
        score
            The score associated with the binary verification output
            (cosine distance).
        prediction
            The prediction is 1 if the two signals in input are from the
            same speaker and 0 otherwise.
        """
        emb1 = self.encode_batch(wavs1, wav1_lens, normalize=False)
        emb2 = self.encode_batch(wavs2, wav2_lens, normalize=False)
        score = self.similarity(emb1, emb2)
        return score, score > threshold

    def verify_files(self, path_x, path_y, **kwargs):
        """Speaker verification with cosine distance.

        Returns the score and the decision (0 different speakers,
        1 same speakers).

        Arguments
        ---------
        path_x : str
            Path to file x.
        path_y : str
            Path to file y.
        **kwargs : dict
            Arguments to ``load_audio``.

        Returns
        -------
        score
            The score associated with the binary verification output
            (cosine distance).
        prediction
            The prediction is 1 if the two signals in input are from the
            same speaker and 0 otherwise.
        """
        waveform_x = self.load_audio(path_x, **kwargs)
        waveform_y = self.load_audio(path_y, **kwargs)
        # Fake batches:
        batch_x = waveform_x.unsqueeze(0)
        batch_y = waveform_y.unsqueeze(0)
        # Verify:
        score, decision = self.verify_batch(batch_x, batch_y)
        # Squeeze:
        return score[0], decision[0]
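

# Minimal usage sketch (not part of the interface above). It assumes a
# pretrained-model directory whose hyperparams yaml defines the modules this
# class expects (codec, tokenizer_config, discrete_embedding_layer,
# attention_mlp, embedding_model). The source/savedir paths and audio file
# names below are hypothetical placeholders.
if __name__ == "__main__":
    import torchaudio

    verifier = CustomEncoderClassifier.from_hparams(
        source="path/to/pretrained_model",  # hypothetical dir or HF repo
        savedir="pretrained_model_cache",
    )

    # Utterance-level embedding from a 16 kHz waveform
    signal, fs = torchaudio.load("enroll.wav")  # hypothetical audio file
    embedding = verifier.encode_batch(signal)
    print("embedding shape:", embedding.shape)

    # Speaker verification: cosine score and a 0/1 same-speaker decision
    score, decision = verifier.verify_files("enroll.wav", "test.wav")
    print(f"score={score.item():.3f}, same_speaker={bool(decision)}")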