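"""Utilities for building a voice-cloning dataset: window an input
recording, transcribe each window, and package the results as a
`datasets.Dataset` of model-ready features."""
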
from typing import Dict, Any

import torch
import librosa
import numpy as np
from datasets import Dataset

from ..cloning.model import CloningModel
from ..transcriber.model import TranscriberModel


def prepare_dataset(example: Dict[str, Any], model: CloningModel) -> Dict[str, Any]:
    """
    Prepare a single example for training
    """
    # feature extraction and tokenization
    processed_example = model.processor(
        text=example["normalized_text"],
        audio_target=example["audio"]["array"],
        sampling_rate=16000,
        return_attention_mask=False,
    )

    # the processor may return batched outputs; strip off the batch dimension
    if len(torch.tensor(processed_example["input_ids"]).shape) > 1:
        processed_example["input_ids"] = processed_example["input_ids"][0]

    # labels hold the target audio features; drop their batch dimension as well
    processed_example["labels"] = processed_example["labels"][0]

    # use the cloning model's SpeechBrain speaker encoder to obtain an x-vector
    processed_example["speaker_embeddings"] = model.create_speaker_embedding(
        torch.tensor(example["audio"]["array"])
    ).numpy()

    return processed_example


def get_cloning_dataset(input_audio_path: str,
                        transcriber_model: TranscriberModel,
                        cloning_model: CloningModel,
                        sampling_rate: int = 16000,
                        window_size_secs: int = 5) -> Dataset:
    """
    Create dataset by transcribing an audio file using a pretrained Wav2Vec2 model.
    """
    speech_array, _ = librosa.load(input_audio_path, sr=sampling_rate)

    # split the waveform into windows of window_size_secs seconds each;
    # np.split at indices [0, window, 2*window, ...] yields an empty leading
    # chunk, which the [1:] discards (the final window may be shorter)
    window = window_size_secs * sampling_rate
    speech_arrays = np.split(speech_array, range(0, len(speech_array), window))[1:]
    texts = [transcriber_model.forward(speech_array, sampling_rate=sampling_rate)
             for speech_array in speech_arrays]

    dataset = Dataset.from_list([
        {"audio": {"array": audio}, "normalized_text": text}
        for audio, text in zip(speech_arrays, texts)
    ])

    # replace the raw columns with the model-ready features from prepare_dataset
    dataset = dataset.map(
        prepare_dataset, fn_kwargs={"model": cloning_model},
        remove_columns=dataset.column_names,
    )

    return dataset
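

# Usage sketch (illustrative only): the constructor calls below are
# assumptions -- adjust to the real TranscriberModel / CloningModel
# signatures in this package. "speaker.wav" is a hypothetical path.
#
#     transcriber = TranscriberModel()
#     cloner = CloningModel()
#     dataset = get_cloning_dataset(
#         "speaker.wav",
#         transcriber_model=transcriber,
#         cloning_model=cloner,
#         window_size_secs=5,
#     )
#     dataset.save_to_disk("cloning_dataset")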