File size: 11,593 Bytes
3a98934
33c23f4
 
1d21972
33c23f4
 
53f6532
a1c7d58
53f6532
 
166aa6c
5195c9e
33c23f4
5195c9e
 
 
33c23f4
206b5fc
 
5195c9e
206b5fc
 
53f6532
5195c9e
206b5fc
 
33c23f4
 
53f6532
33c23f4
 
 
763091b
33c23f4
 
 
 
 
 
 
 
763091b
53f6532
 
 
 
 
 
763091b
 
 
 
 
 
 
 
 
 
 
 
33c23f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53f6532
33c23f4
 
 
 
 
a1c7d58
53f6532
33c23f4
a1c7d58
abdf62b
33c23f4
abdf62b
33c23f4
a1c7d58
abdf62b
33c23f4
 
 
53f6532
abdf62b
 
 
 
 
763091b
5195c9e
 
763091b
 
5195c9e
33c23f4
 
 
 
 
 
 
 
 
 
5195c9e
53f6532
 
 
 
 
 
 
33c23f4
53f6532
33c23f4
 
 
 
 
 
53f6532
 
 
 
 
 
 
 
 
33c23f4
53f6532
33c23f4
 
 
 
206b5fc
5195c9e
206b5fc
33c23f4
abdf62b
5195c9e
33c23f4
 
 
 
763091b
abdf62b
33c23f4
763091b
5195c9e
33c23f4
5195c9e
 
53f6532
 
206b5fc
33c23f4
5195c9e
53f6532
 
206b5fc
166aa6c
abdf62b
53f6532
 
 
 
 
 
5195c9e
53f6532
 
 
 
 
 
5195c9e
53f6532
33c23f4
53f6532
 
 
 
 
5195c9e
 
 
 
33c23f4
 
53f6532
5195c9e
 
206b5fc
abdf62b
206b5fc
53f6532
 
 
abdf62b
53f6532
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5195c9e
 
 
 
 
 
 
 
 
 
1d21972
 
 
 
 
 
 
5195c9e
 
 
1d21972
 
53f6532
1d21972
 
53f6532
 
b8ef56d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
import os
import json
import random
import argparse
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, Subset
from huggingface_hub import upload_folder
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from collections import Counter
from transformers.integrations import TensorBoardCallback
from transformers import (
    Wav2Vec2FeatureExtractor, HubertConfig, HubertForSequenceClassification,
    Trainer, TrainingArguments,
    EarlyStoppingCallback
)

# Module-level configuration shared by the dataset, model loading and training code below.
MODEL = "ntu-spml/distilhubert" # base model
FEATURE_EXTRACTOR = Wav2Vec2FeatureExtractor.from_pretrained(MODEL) # feature extractor of the base model (loaded at import time)
seed = 123
MAX_DURATION = 1.00 # Maximum audio duration; multiplied by SAMPLING_RATE to obtain the clip length in samples
SAMPLING_RATE = FEATURE_EXTRACTOR.sampling_rate # 16kHz (per the original author's note)
token = os.getenv("HF_TOKEN")
config_file = "models_config.json"
batch_size = 1024 # TODO: review whether this is still needed
num_workers = 12 # CPU cores

class AudioDataset(Dataset):
    """Audio classification dataset backed by one sub-directory per label.

    Every file found in ``dataset_path/<label_dir>`` becomes one sample whose
    label id is ``label2id[label_dir]``. Audio is loaded and preprocessed
    lazily in ``__getitem__``.

    Args:
        dataset_path: Root directory that contains the label sub-directories.
        label2id: Mapping from label directory name to integer class id.
        filter_white_noise: Stored on the instance as ``filter_white_noise``;
            it is not consulted anywhere inside this class.
        undersample_normal: When true (and ``label2id`` is non-empty), the
            ``'1s_normal'`` class is randomly undersampled right after the
            file scan.
    """

    def __init__(self, dataset_path, label2id, filter_white_noise, undersample_normal):
        self.dataset_path = dataset_path
        self.label2id = label2id
        self.file_paths = []
        self.filter_white_noise = filter_white_noise
        self.labels = []
        for label_dir, label_id in self.label2id.items():
            label_path = os.path.join(self.dataset_path, label_dir)
            if os.path.isdir(label_path):
                # os.listdir order is arbitrary; sorting keeps sample indices
                # (and therefore seeded shuffles/samples) stable across runs.
                for file_name in sorted(os.listdir(label_path)):
                    audio_path = os.path.join(label_path, file_name)
                    self.file_paths.append(audio_path)
                    self.labels.append(label_id)
        if undersample_normal and self.label2id:
            self.undersample_normal_class()

    def undersample_normal_class(self):
        """Randomly shrink the ``'1s_normal'`` class to the size of the largest other class.

        The relative order of the surviving samples is preserved. Nothing is
        changed when the normal class is absent, when there is no other class,
        or when the normal class is already no larger than the largest other
        class.
        """
        normal_label = self.label2id.get('1s_normal')
        normal_indices = [i for i, label in enumerate(self.labels) if label == normal_label]
        other_counts = [
            count for label, count in Counter(self.labels).items() if label != normal_label
        ]
        if not normal_indices or not other_counts:
            return
        # random.sample raises ValueError when asked for more items than exist,
        # so never request more than the normal class actually has.
        target_count = min(max(other_counts), len(normal_indices))
        # A set gives O(1) membership tests in the filtering pass below.
        keep_indices = set(random.sample(normal_indices, target_count))
        kept = [
            (path, label)
            for i, (path, label) in enumerate(zip(self.file_paths, self.labels))
            if label != normal_label or i in keep_indices
        ]
        self.file_paths = [path for path, _ in kept]
        self.labels = [label for _, label in kept]

    def __len__(self):
        """Return the number of audio files in the dataset."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return a dict with the preprocessed ``input_values`` and the ``labels`` tensor for sample ``idx``."""
        audio_path = self.file_paths[idx]
        label = self.labels[idx]
        input_values = self.preprocess_audio(audio_path)
        return {
            "input_values": input_values,
            "labels": torch.tensor(label)
        }

    def preprocess_audio(self, audio_path):
        """Load an audio file and turn it into fixed-length model input values.

        Steps: load, resample to ``SAMPLING_RATE`` if needed, mix down to one
        channel, peak-normalize, truncate or zero-pad to
        ``SAMPLING_RATE * MAX_DURATION`` samples, then run ``FEATURE_EXTRACTOR``.

        Returns:
            The extractor's ``input_values`` with size-1 dimensions squeezed out.
        """
        waveform, sample_rate = torchaudio.load(
            audio_path,
            normalize=True,
            )
        if sample_rate != SAMPLING_RATE: # Resample when the file is not at the target rate
            resampler = torchaudio.transforms.Resample(sample_rate, SAMPLING_RATE)
            waveform = resampler(waveform)
        if waveform.shape[0] > 1: # More than one channel: average them down to mono
            waveform = waveform.mean(dim=0, keepdim=True)
        # Peak normalization; the 1e-6 avoids dividing by zero on silent clips.
        # TODO: try removing this since normalization is reportedly already done; without the 1e-6 the accuracy is terrible!!
        waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-6)
        max_length = int(SAMPLING_RATE * MAX_DURATION)
        if waveform.shape[1] > max_length:
            waveform = waveform[:, :max_length] # Truncate
        else:
            waveform = torch.nn.functional.pad(waveform, (0, max_length - waveform.shape[1])) # Zero-pad on the right
        inputs = FEATURE_EXTRACTOR(
            waveform.squeeze(),
            sampling_rate=SAMPLING_RATE, # Passed explicitly, just in case
            return_tensors="pt",
        )
        return inputs.input_values.squeeze()
    
def is_white_noise(audio):
    """Heuristically flag a tensor as white noise.

    The result is truthy when the absolute mean of ``audio`` is below 0.001
    and its standard deviation is below 0.01. The value returned is the
    0-dim boolean tensor produced by the comparisons, not a Python ``bool``.
    """
    near_zero_mean = audio.mean().abs() < 0.001
    # `and` short-circuits: the spread is only evaluated when the mean test passes.
    return near_zero_mean and audio.std() < 0.01

def seed_everything(): # TODO: check whether anything else is needed
    """Seed the random number generators used by this script with the module-level ``seed``.

    Python's ``random`` module is seeded as well because the train/test split
    and the class undersampling in this file draw from it.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # NOTE: the cuDNN flags below are left disabled, as in the original code.
    # torch.backends.cudnn.deterministic = True # For reproducibility
    # torch.backends.cudnn.benchmark = False # For reproducibility

def build_label_mappings(dataset_path):
    """Build label <-> id mappings from the sub-directories of ``dataset_path``.

    Only directories are treated as labels; plain files are ignored. Directory
    names are sorted before ids are assigned, because ``os.listdir`` returns
    entries in arbitrary order and the ids must be identical every time the
    mapping is rebuilt (e.g. at training time and later at prediction time).

    Args:
        dataset_path: Directory whose immediate sub-directories are the labels.

    Returns:
        A ``(label2id, id2label)`` tuple of dicts with ids starting at 0.
    """
    label_dirs = sorted(
        entry
        for entry in os.listdir(dataset_path)
        if os.path.isdir(os.path.join(dataset_path, entry))
    )
    id2label = dict(enumerate(label_dirs))
    label2id = {label_dir: label_id for label_id, label_dir in id2label.items()}
    return label2id, id2label

def compute_class_weights(labels):
    """Return one sampling weight per entry of ``labels``.

    Each weight is ``len(labels) / (number of occurrences of that label)``, so
    samples from rarer classes receive larger weights. The output list is
    aligned with ``labels``.
    """
    n_samples = len(labels)
    frequency = Counter(labels)
    return [n_samples / frequency[lbl] for lbl in labels]

def create_dataloader(dataset_path, filter_white_noise, undersample_normal, test_size=0.2, shuffle=True, pin_memory=True):
    """Build the train and test data loaders for the dataset at ``dataset_path``.

    The dataset indices are shuffled with Python's ``random`` and split so that
    the last ``test_size`` fraction becomes the test subset. The training
    loader draws samples with replacement through a ``WeightedRandomSampler``
    weighted by inverse class frequency; the test loader uses ``shuffle``.

    Returns:
        ``(train_dataloader, test_dataloader, id2label)``.
    """
    label2id, id2label = build_label_mappings(dataset_path)
    full_dataset = AudioDataset(dataset_path, label2id, filter_white_noise, undersample_normal)

    # Random index permutation, then a single cut point for the split.
    total = len(full_dataset)
    permutation = list(range(total))
    random.shuffle(permutation)
    cut = int(total * (1 - test_size))
    train_idx, test_idx = permutation[:cut], permutation[cut:]

    train_subset = Subset(full_dataset, train_idx)
    test_subset = Subset(full_dataset, test_idx)

    # Per-sample weights computed from the training labels only.
    sample_weights = compute_class_weights([full_dataset.labels[i] for i in train_idx])
    balanced_sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(train_subset),
        replacement=True,
    )

    shared_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
    }
    train_loader = DataLoader(train_subset, sampler=balanced_sampler, **shared_kwargs)
    test_loader = DataLoader(test_subset, shuffle=shuffle, **shared_kwargs)
    return train_loader, test_loader, id2label

def load_model(model_path, id2label, num_labels):
    """Load a HuBERT sequence-classification model and move it to the available device.

    Args:
        model_path: Model name or path passed to ``from_pretrained``.
        id2label: Mapping from class id to label name stored in the config.
        num_labels: Number of output classes.

    Returns:
        The model, placed on CUDA when available and on CPU otherwise.
    """
    hubert_config = HubertConfig.from_pretrained(
        pretrained_model_name_or_path=model_path,
        num_labels=num_labels,
        id2label=id2label,
        finetuning_task="audio-classification",
    )
    classifier = HubertForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=model_path,
        config=hubert_config,
        torch_dtype=torch.float32,  # TODO: check whether float32 is required or float16 could be used instead
    )
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    classifier.to(torch.device(device_name))
    return classifier

def train_params(dataset_path, filter_white_noise, undersample_normal):
    """Create the data loaders and a fresh classifier initialised from the base ``MODEL``.

    Returns:
        ``(model, train_dataloader, test_dataloader, id2label)``.
    """
    train_loader, test_loader, id2label = create_dataloader(
        dataset_path, filter_white_noise, undersample_normal
    )
    n_classes = len(id2label)
    classifier = load_model(MODEL, id2label, num_labels=n_classes)
    return classifier, train_loader, test_loader, id2label

def predict_params(dataset_path, model_path, filter_white_noise, undersample_normal):
    """Load the model stored at ``model_path`` together with the dataset's id-to-label mapping.

    The data loaders are built only to obtain ``id2label``; they are discarded.

    Returns:
        ``(model, id2label)``.
    """
    id2label = create_dataloader(dataset_path, filter_white_noise, undersample_normal)[2]
    classifier = load_model(model_path, id2label, num_labels=len(id2label))
    return classifier, id2label

def compute_metrics(pred):
    """Compute evaluation metrics from a prediction object.

    ``pred`` must expose ``label_ids`` and ``predictions``; the predicted class
    is the argmax over the last axis of ``predictions``.

    Returns:
        A dict with ``accuracy``, weighted ``f1``/``precision``/``recall`` and
        the ``confusion_matrix`` as nested lists.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    weighted_scores = precision_recall_fscore_support(y_true, y_pred, average='weighted')
    accuracy = accuracy_score(y_true, y_pred)
    matrix = confusion_matrix(y_true, y_pred)
    metrics = {}
    metrics['accuracy'] = accuracy
    metrics['f1'] = weighted_scores[2]
    metrics['precision'] = weighted_scores[0]
    metrics['recall'] = weighted_scores[1]
    metrics['confusion_matrix'] = matrix.tolist()
    return metrics

def main(training_args, output_dir, dataset_path, filter_white_noise, undersample_normal):
    """Train, evaluate, publish and spot-check the audio classifier.

    Args:
        training_args: ``TrainingArguments`` instance handed to the ``Trainer``.
        output_dir: Local directory where the trained model is saved; it is
            also used as the repository name under the ``A-POR-LOS-8000``
            organisation when uploading.
        dataset_path: Root directory of the dataset (one sub-directory per label).
        filter_white_noise: Forwarded to the dataset construction.
        undersample_normal: Forwarded to the dataset construction.
    """
    seed_everything()
    model, train_dataloader, test_dataloader, id2label = train_params(dataset_path, filter_white_noise, undersample_normal)
    early_stopping_callback = EarlyStoppingCallback(
        early_stopping_patience=5,
        early_stopping_threshold=0.001
        )
    trainer = Trainer(
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataloader.dataset,
        eval_dataset=test_dataloader.dataset,
        callbacks=[TensorBoardCallback, early_stopping_callback]
    )
    torch.cuda.empty_cache() # free GPU memory
    trainer.train() # pass resume_from_checkpoint to continue a previous run
    os.makedirs(output_dir, exist_ok=True)
    trainer.save_model(output_dir) # Save the model locally
    eval_results = trainer.evaluate()
    print(f"Evaluation results: {eval_results}")
    trainer.push_to_hub(token=token) # Push the model to the user profile
    upload_folder(repo_id=f"A-POR-LOS-8000/{output_dir}", folder_path=output_dir, token=token) # Upload the local folder to the organisation

    test_subset = test_dataloader.dataset  # Subset wrapping the full AudioDataset
    full_dataset = test_subset.dataset

    def predict(audio_path):
        """Return ``(predicted_label, logits)`` for a single audio file."""
        # Reuse the dataset's preprocessing so training and inference inputs
        # are guaranteed to be built the same way; add the batch dimension.
        input_values = full_dataset.preprocess_audio(audio_path).unsqueeze(0)
        with torch.no_grad():
            logits = model(input_values.to(model.device)).logits
        predicted_class_id = logits.argmax().item()
        return id2label[predicted_class_id], logits

    # Make sure dropout is disabled for the spot-check predictions.
    model.eval()
    # Draw the spot-check files from the held-out subset only (not from files
    # the model was trained on), and never ask for more files than exist.
    held_out_paths = [full_dataset.file_paths[i] for i in test_subset.indices]
    test_samples = random.sample(held_out_paths, min(15, len(held_out_paths)))
    for sample in test_samples:
        predicted_label, logits = predict(sample)
        print(f"File: {sample}")
        print(f"Predicted label: {predicted_label}")
        print(f"Logits: {logits}")
        print("---")

def load_config(model_name):
    """Read the entry for ``model_name`` from the JSON file named by ``config_file``.

    The entry's ``"training_args"`` dict is replaced by a ``TrainingArguments``
    instance built from it.

    Returns:
        The model's configuration dict.

    Raises:
        KeyError: If ``model_name`` (or its ``"training_args"`` key) is missing.
    """
    with open(config_file, 'r') as handle:
        all_configs = json.load(handle)
    selected = all_configs[model_name]
    selected["training_args"] = TrainingArguments(**selected["training_args"])
    return selected

if __name__ == "__main__":
    # Command-line entry point: `--n` selects which configured model to train.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--n",
        choices=["mon", "class"],
        required=True,
        help="Elegir qué modelo entrenar",
    )
    selected_model = cli.parse_args().n
    settings = load_config(selected_model)
    # "class" enables both white-noise filtering and undersampling of the
    # normal class; "mon" (the only other allowed choice) disables both.
    class_mode = selected_model == "class"
    main(
        settings["training_args"],
        settings["output_dir"],
        settings["dataset_path"],
        class_mode,
        class_mode,
    )