import os
import json
import random
import argparse
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from huggingface_hub import upload_folder
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers.integrations import TensorBoardCallback
from transformers import (
    Wav2Vec2FeatureExtractor, HubertConfig, HubertForSequenceClassification,
    Trainer, TrainingArguments,
    EarlyStoppingCallback
)

MODEL = "ntu-spml/distilhubert"  # base model
FEATURE_EXTRACTOR = Wav2Vec2FeatureExtractor.from_pretrained(MODEL)  # feature extractor of the base model
seed = 123
MAX_DURATION = 1.00  # maximum audio duration, in seconds
SAMPLING_RATE = FEATURE_EXTRACTOR.sampling_rate  # 16 kHz
token = os.getenv("HF_TOKEN")  # TODO: try storing the token in a local file instead
config_file = "models_config.json"
batch_size = 1024  # TODO: check whether this is still necessary
num_workers = 12  # CPU cores

class AudioDataset(Dataset):
    def __init__(self, dataset_path, label2id, filter_white_noise):
        self.dataset_path = dataset_path
        self.label2id = label2id
        self.filter_white_noise = filter_white_noise
        self.file_paths = []
        self.labels = []
        for label_dir, label_id in self.label2id.items():
            label_path = os.path.join(self.dataset_path, label_dir)
            if os.path.isdir(label_path):
                for file_name in os.listdir(label_path):
                    audio_path = os.path.join(label_path, file_name)
                    self.file_paths.append(audio_path)
                    self.labels.append(label_id)
        self.file_paths.sort(key=lambda x: x.split('_part')[0])  # group chunks of the same source file; unclear whether this ordering matters

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        audio_path = self.file_paths[idx]
        label = self.labels[idx]
        input_values = self.preprocess_audio(audio_path)
        return {
            "input_values": input_values,
            "labels": torch.tensor(label)
        }

    def preprocess_audio(self, audio_path):
        waveform, sample_rate = torchaudio.load(
            audio_path,
            normalize=True,  # convert to float32
        )
        if sample_rate != SAMPLING_RATE:  # resample if not 16 kHz
            resampler = torchaudio.transforms.Resample(sample_rate, SAMPLING_RATE)
            waveform = resampler(waveform)
        if waveform.shape[0] > 1:  # downmix stereo to mono
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-6)  # peak-normalize; the 1e-6 avoids division by zero (without it accuracy is terrible!)
        max_length = int(SAMPLING_RATE * MAX_DURATION)
        if waveform.shape[1] > max_length:
            waveform = waveform[:, :max_length]  # truncate
        else:
            waveform = torch.nn.functional.pad(waveform, (0, max_length - waveform.shape[1]))  # pad
        inputs = FEATURE_EXTRACTOR(
            waveform.squeeze(),
            sampling_rate=SAMPLING_RATE,  # set explicitly, just in case
            return_tensors="pt",
            # max_length=int(SAMPLING_RATE * MAX_DURATION),
            # truncation=True,  # done manually above
            # padding=True,  # done manually above
        )
        return inputs.input_values.squeeze()
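
# For reference, the commented-out kwargs in preprocess_audio would let the
# feature extractor handle truncation and padding itself. A sketch of the
# equivalent call (untested against this exact pipeline):
#
#   FEATURE_EXTRACTOR(
#       waveform.squeeze(), sampling_rate=SAMPLING_RATE,
#       max_length=int(SAMPLING_RATE * MAX_DURATION),
#       truncation=True, padding="max_length", return_tensors="pt",
#   )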

def is_white_noise(audio):
    # Heuristic: flags audio whose mean is near zero and whose standard deviation is very small.
    mean = torch.mean(audio)
    std = torch.std(audio)
    return torch.abs(mean) < 0.001 and std < 0.01
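
# NOTE: `filter_white_noise` is stored by AudioDataset but never applied in this
# script, and `is_white_noise` is otherwise unused. A minimal sketch of how the
# flag could be wired in (hypothetical helper, not part of the original pipeline):
def keep_non_white_noise_indices(dataset):
    """Return the indices of samples whose preprocessed audio is not white noise."""
    return [
        idx for idx in range(len(dataset))
        if not is_white_noise(dataset[idx]["input_values"])
    ]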

def seed_everything():
    random.seed(seed)  # needed: `random.shuffle` drives the train/test split below
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16384:8'  # NB: the PyTorch reproducibility docs list ':16:8' or ':4096:8'

def build_label_mappings(dataset_path):
    label2id = {}
    id2label = {}
    label_id = 0
    for label_dir in sorted(os.listdir(dataset_path)):  # sort for stable label ids across runs
        if os.path.isdir(os.path.join(dataset_path, label_dir)):
            label2id[label_dir] = label_id
            id2label[label_id] = label_dir
            label_id += 1
    return label2id, id2label
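
# Expected dataset layout (illustrative; the names are hypothetical):
#   dataset_path/
#       label_a/  clip_part0.wav, clip_part1.wav, ...
#       label_b/  other_part0.wav, ...
# Each subdirectory becomes one class label.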

def create_dataloader(dataset_path, filter_white_noise, test_size=0.2, shuffle=True, pin_memory=True):
    label2id, id2label = build_label_mappings(dataset_path)
    dataset = AudioDataset(dataset_path, label2id, filter_white_noise)
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    random.shuffle(indices)
    split_idx = int(dataset_size * (1 - test_size))
    train_indices = indices[:split_idx]
    test_indices = indices[split_idx:]
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    test_dataset = torch.utils.data.Subset(dataset, test_indices)
    train_dataloader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory
    )
    test_dataloader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory  # no shuffling needed for evaluation
    )
    return train_dataloader, test_dataloader, id2label

def load_model(model_path, id2label, num_labels):
    config = HubertConfig.from_pretrained(
        pretrained_model_name_or_path=model_path,
        num_labels=num_labels,
        id2label=id2label,
        finetuning_task="audio-classification"
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = HubertForSequenceClassification.from_pretrained(  # TODO: review the parameters; possible optimizations
        pretrained_model_name_or_path=model_path,
        config=config,
        torch_dtype=torch.float32,
    )
    model.to(device)
    return model

def train_params(dataset_path, filter_white_noise):
    train_dataloader, test_dataloader, id2label = create_dataloader(dataset_path, filter_white_noise)
    model = load_model(MODEL, id2label, num_labels=len(id2label))
    return model, train_dataloader, test_dataloader, id2label

def predict_params(dataset_path, model_path, filter_white_noise):
    _, _, id2label = create_dataloader(dataset_path, filter_white_noise)
    model = load_model(model_path, id2label, num_labels=len(id2label))
    return model, id2label

def compute_metrics(eval_pred):
    predictions = torch.argmax(torch.tensor(eval_pred.predictions), dim=-1)
    references = torch.tensor(eval_pred.label_ids)
    accuracy = accuracy_score(references, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(references, predictions, average='weighted')
    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

def main(training_args, output_dir, dataset_path, filter_white_noise):
    seed_everything()
    model, train_dataloader, test_dataloader, _ = train_params(dataset_path, filter_white_noise)
    trainer = Trainer(
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataloader.dataset,  # Trainer builds its own DataLoaders, so the batch_size/num_workers set above do not apply here
        eval_dataset=test_dataloader.dataset,
        callbacks=[TensorBoardCallback(), EarlyStoppingCallback(early_stopping_patience=3)]  # early stopping requires load_best_model_at_end=True in the TrainingArguments
    )
    torch.cuda.empty_cache()  # free GPU memory
    trainer.train()  # pass resume_from_checkpoint=True to resume training
    # trainer.save_model(output_dir)  # save the model locally
    os.makedirs(output_dir, exist_ok=True)
    trainer.push_to_hub(token=token)  # push the model to the user profile
    upload_folder(repo_id=output_dir, folder_path=output_dir, token=token)  # upload to the organization repo and keep a local copy

def load_config(model_name):
    with open(config_file, 'r') as f:
        config = json.load(f)
    model_config = config[model_name]
    training_args = TrainingArguments(**model_config["training_args"])
    model_config["training_args"] = training_args
    return model_config
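
# Illustrative shape of models_config.json (keys inferred from the accesses in
# this script; the concrete values are project-specific):
# {
#     "mon":   {"training_args": {"...": "TrainingArguments kwargs"},
#               "output_dir": "...", "dataset_path": "..."},
#     "class": {"...": "..."}
# }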

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--n", choices=["mon", "class"],
        required=True, help="Which model to train"
    )
    args = parser.parse_args()
    config = load_config(args.n)
    training_args = config["training_args"]
    output_dir = config["output_dir"]
    dataset_path = config["dataset_path"]
    if args.n == "mon":
        filter_white_noise = False
    elif args.n == "class":
        filter_white_noise = True
    main(training_args, output_dir, dataset_path, filter_white_noise)
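
# Usage (the filename train.py is illustrative):
#   export HF_TOKEN=<your Hugging Face token>
#   python train.py --n mon    # trains with filter_white_noise=False
#   python train.py --n class  # trains with filter_white_noise=True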