# CHATBOT / model.py
# Author: Marcos12886 — last commit 206b5fc: "Usar label2id menos" ("Use label2id less")
# File size: 8.32 kB
import os
import json
import random
import argparse
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from huggingface_hub import upload_folder
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers.integrations import TensorBoardCallback
from transformers import (
Wav2Vec2FeatureExtractor, HubertConfig, HubertForSequenceClassification,
Trainer, TrainingArguments,
EarlyStoppingCallback
)
MODEL = "ntu-spml/distilhubert"  # base model (Hugging Face Hub id)
# Feature extractor of the base model; loaded once, at import time.
FEATURE_EXTRACTOR = Wav2Vec2FeatureExtractor.from_pretrained(MODEL)
seed = 123  # seed read by seed_everything()
# Maximum clip length in seconds: preprocess_audio truncates longer clips
# and zero-pads shorter ones to int(SAMPLING_RATE * MAX_DURATION) samples.
MAX_DURATION = 1.00
SAMPLING_RATE = FEATURE_EXTRACTOR.sampling_rate  # 16kHz per the author; audio is resampled to this rate
# Hugging Face token from the environment (None if HF_TOKEN is unset).
token = os.getenv("HF_TOKEN")  # TODO: try storing the token in a local file
config_file = "models_config.json"  # JSON file read by load_config(), one entry per model name
# Batch size of the DataLoaders built in create_dataloader. NOTE(review): main()
# hands only the underlying datasets to Trainer, so Trainer does not use this value.
batch_size = 1024  # TODO: check whether this is still needed
num_workers = 12  # DataLoader worker processes (author's note: number of CPU cores)
class AudioDataset(Dataset):
    """Audio-classification dataset laid out as one sub-directory per label.

    Every entry found directly inside ``<dataset_path>/<label_dir>`` becomes one
    sample whose label is ``label2id[label_dir]``. Audio is loaded lazily and
    preprocessed in :meth:`__getitem__`.

    Args:
        dataset_path: Root directory containing one sub-directory per label.
        label2id: Mapping from label directory name to integer class id.
            Names whose directory does not exist are skipped.
        filter_white_noise: Stored on the instance. NOTE(review): no code in
            this class reads it, so no filtering currently happens.
    """

    def __init__(self, dataset_path, label2id, filter_white_noise):
        self.dataset_path = dataset_path
        self.label2id = label2id
        self.filter_white_noise = filter_white_noise
        samples = []
        for label_dir, label_id in self.label2id.items():
            label_path = os.path.join(self.dataset_path, label_dir)
            if not os.path.isdir(label_path):
                continue
            for file_name in os.listdir(label_path):
                samples.append((os.path.join(label_path, file_name), label_id))
        # Sort paths and labels *together* so that file_paths[i] keeps its own
        # label (sorting only the paths used to scramble the labels). The
        # primary key is the path before "_part" (presumably grouping chunks
        # of one recording); the full path breaks ties so the order never
        # depends on os.listdir().
        samples.sort(key=lambda pair: (pair[0].split('_part')[0], pair[0]))
        self.file_paths = [path for path, _ in samples]
        self.labels = [label for _, label in samples]

    def __len__(self):
        """Return the number of audio files found."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``{"input_values": tensor, "labels": tensor}`` for sample *idx*."""
        audio_path = self.file_paths[idx]
        label = self.labels[idx]
        input_values = self.preprocess_audio(audio_path)
        return {
            "input_values": input_values,
            "labels": torch.tensor(label),
        }

    def preprocess_audio(self, audio_path):
        """Load one audio file and convert it into model input values.

        Steps: load as float, resample to SAMPLING_RATE if needed, down-mix to
        mono, peak-normalise, truncate or zero-pad to MAX_DURATION seconds, and
        finally run FEATURE_EXTRACTOR.

        Returns:
            The squeezed ``input_values`` tensor from the feature extractor
            (presumably 1-D of length ``int(SAMPLING_RATE * MAX_DURATION)`` —
            TODO confirm the extractor preserves length).
        """
        waveform, sample_rate = torchaudio.load(
            audio_path,
            normalize=True,  # convert to float32
        )
        if sample_rate != SAMPLING_RATE:  # resample when the file's rate differs
            resampler = torchaudio.transforms.Resample(sample_rate, SAMPLING_RATE)
            waveform = resampler(waveform)
        if waveform.shape[0] > 1:  # multi-channel -> mono by averaging channels
            waveform = waveform.mean(dim=0, keepdim=True)
        # Peak-normalise; the epsilon avoids division by zero on silent clips
        # (author's note: accuracy was terrible without the 1e-6).
        waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-6)
        max_length = int(SAMPLING_RATE * MAX_DURATION)
        if waveform.shape[1] > max_length:
            waveform = waveform[:, :max_length]  # truncate
        else:
            waveform = torch.nn.functional.pad(waveform, (0, max_length - waveform.shape[1]))  # zero-pad
        # Truncation and padding are done manually above, so they are not
        # requested from the feature extractor; the rate is passed explicitly
        # just in case.
        inputs = FEATURE_EXTRACTOR(
            waveform.squeeze(),
            sampling_rate=SAMPLING_RATE,
            return_tensors="pt",
        )
        return inputs.input_values.squeeze()
def is_white_noise(audio, mean_threshold=0.001, std_threshold=0.01):
    """Return True if *audio* is a low-level, near-zero-mean signal.

    A clip qualifies when ``|mean| < mean_threshold`` and
    ``std < std_threshold``. The result is a plain Python ``bool`` rather than a
    0-d tensor, so it can be compared with ``is True`` or serialised.

    NOTE(review): no caller is visible in this module.

    Args:
        audio: Float tensor of samples.
        mean_threshold: Upper bound on the absolute mean.
        std_threshold: Upper bound on the standard deviation.
    """
    mean = torch.mean(audio)
    std = torch.std(audio)
    return bool(torch.abs(mean) < mean_threshold and std < std_threshold)
def seed_everything(seed_value=None):
    """Seed every RNG this script relies on, for reproducible runs.

    Seeds Python's ``random`` module as well as torch (CPU and all CUDA
    devices). ``random`` matters because ``create_dataloader`` shuffles the
    train/test split with ``random.shuffle`` before the Trainer exists. Without
    seeding it, the split changed on every run.

    Args:
        seed_value: Seed to use; defaults to the module-level ``seed``.
    """
    if seed_value is None:
        seed_value = seed
    random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)  # every GPU, not only the current one
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Needed for deterministic cuBLAS; only takes effect if set before cuBLAS
    # is initialised.
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16384:8'
def build_label_mappings(dataset_path):
    """Map each label sub-directory of *dataset_path* to an integer id.

    Sub-directories are enumerated in sorted order, so the ids are identical on
    every machine and every run; ``os.listdir`` order is arbitrary. This matters
    because the mapping is rebuilt from the directory at prediction time.
    Non-directory entries are ignored.

    NOTE: a model trained with the previous listdir-based ordering may have
    used a different id assignment.

    Returns:
        (label2id, id2label): ``{name: id}`` and its inverse ``{id: name}``.
    """
    label_dirs = sorted(
        entry for entry in os.listdir(dataset_path)
        if os.path.isdir(os.path.join(dataset_path, entry))
    )
    label2id = {name: idx for idx, name in enumerate(label_dirs)}
    id2label = {idx: name for idx, name in enumerate(label_dirs)}
    return label2id, id2label
def create_dataloader(dataset_path, filter_white_noise, test_size=0.2, shuffle=True, pin_memory=True):
    """Build the dataset, split it randomly into train/test, and wrap both in DataLoaders.

    The split uses Python's ``random.shuffle``; seed the ``random`` module
    beforehand for a reproducible split.

    Args:
        dataset_path: Root directory with one sub-directory per label.
        filter_white_noise: Forwarded to ``AudioDataset``.
        test_size: Fraction of samples (0..1) placed in the test split.
        shuffle: Whether the *training* DataLoader reshuffles each epoch. The
            test DataLoader is never shuffled: evaluation order does not change
            the metrics, and a fixed order keeps predictions aligned across runs.
        pin_memory: Forwarded to both DataLoaders.

    Returns:
        (train_dataloader, test_dataloader, id2label)

    Raises:
        ValueError: if ``test_size`` is outside [0, 1].
    """
    if not 0.0 <= test_size <= 1.0:
        raise ValueError(f"test_size must be in [0, 1], got {test_size!r}")
    label2id, id2label = build_label_mappings(dataset_path)
    dataset = AudioDataset(dataset_path, label2id, filter_white_noise)
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    random.shuffle(indices)
    split_idx = int(dataset_size * (1 - test_size))
    train_dataset = torch.utils.data.Subset(dataset, indices[:split_idx])
    test_dataset = torch.utils.data.Subset(dataset, indices[split_idx:])
    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    train_dataloader = DataLoader(train_dataset, shuffle=shuffle, **loader_kwargs)
    test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
    return train_dataloader, test_dataloader, id2label
def load_model(model_path, id2label, num_labels):
    """Load a HuBERT sequence-classification model and move it to the best device.

    Args:
        model_path: Hub id or local path of the checkpoint (the base model for
            training, a fine-tuned model for prediction).
        id2label: ``{class_id: label_name}`` mapping stored in the config.
        num_labels: Number of output classes; expected to equal ``len(id2label)``.

    Returns:
        A ``HubertForSequenceClassification`` in float32, on CUDA if available,
        otherwise on CPU.
    """
    # Pass the inverse mapping explicitly. Otherwise label2id keeps whatever the
    # checkpoint config (or the library default) provides, which need not
    # match id2label, and that stale mapping is saved/pushed with the model.
    label2id = {name: idx for idx, name in id2label.items()}
    config = HubertConfig.from_pretrained(
        pretrained_model_name_or_path=model_path,
        num_labels=num_labels,
        id2label=id2label,
        label2id=label2id,
        finetuning_task="audio-classification",
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = HubertForSequenceClassification.from_pretrained(  # TODO: review parameters, possible optimisations
        pretrained_model_name_or_path=model_path,
        config=config,
        torch_dtype=torch.float32,
    )
    model.to(device)
    return model
def train_params(dataset_path, filter_white_noise):
    """Prepare training inputs: the split DataLoaders and a fresh base model.

    The classification head is sized from the label mapping discovered in
    *dataset_path*.

    Returns:
        (model, train_dataloader, test_dataloader, id2label)
    """
    train_loader, test_loader, labels_by_id = create_dataloader(
        dataset_path, filter_white_noise
    )
    class_count = len(labels_by_id)
    base_model = load_model(MODEL, labels_by_id, num_labels=class_count)
    return base_model, train_loader, test_loader, labels_by_id
def predict_params(dataset_path, model_path, filter_white_noise):
    """Rebuild the label mapping from *dataset_path* and load the model at *model_path*.

    Only the label mapping is needed here, so it is read straight from the label
    directories instead of building the whole dataset, shuffling it, and
    creating DataLoaders via ``create_dataloader`` (which returned the same
    mapping).

    Args:
        dataset_path: Root directory with one sub-directory per label.
        model_path: Hub id or local path of the fine-tuned model.
        filter_white_noise: Kept for interface compatibility; not used.

    Returns:
        (model, id2label)
    """
    _, id2label = build_label_mappings(dataset_path)
    model = load_model(model_path, id2label, num_labels=len(id2label))
    return model, id2label
def compute_metrics(eval_pred):
    """Compute accuracy and support-weighted precision/recall/F1 for the Trainer.

    Args:
        eval_pred: Object with ``predictions`` (logits, presumably of shape
            (n_samples, n_classes) — the argmax is taken over the last axis)
            and ``label_ids``. ``predictions`` may also be a tuple whose first
            element holds the logits.

    Returns:
        Dict with "accuracy", "precision", "recall" and "f1".
    """
    logits = eval_pred.predictions
    if isinstance(logits, tuple):  # model returned (logits, extra outputs...)
        logits = logits[0]
    predictions = torch.argmax(torch.tensor(logits), dim=-1)
    references = torch.tensor(eval_pred.label_ids)
    accuracy = accuracy_score(references, predictions)
    # zero_division=0 gives the same value sklearn's default ("warn") uses for
    # never-predicted classes, without an UndefinedMetricWarning per evaluation.
    precision, recall, f1, _ = precision_recall_fscore_support(
        references, predictions, average='weighted', zero_division=0
    )
    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }
def main(training_args, output_dir, dataset_path, filter_white_noise):
    """Fine-tune the base model on *dataset_path*, then publish it to the Hub.

    Args:
        training_args: ``TrainingArguments`` (built by ``load_config``).
        output_dir: Directory created locally, uploaded with ``upload_folder``,
            and also used as that upload's ``repo_id``.
        dataset_path: Root directory with one sub-directory per label.
        filter_white_noise: Forwarded to the dataset construction.
    """
    seed_everything()
    model, train_dataloader, test_dataloader, _ = train_params(dataset_path, filter_white_noise)
    trainer = Trainer(
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        # Only the underlying Subsets are handed to Trainer; the DataLoaders'
        # own batch_size/num_workers/shuffle settings are not used here.
        train_dataset=train_dataloader.dataset,
        eval_dataset=test_dataloader.dataset,
        # NOTE(review): EarlyStoppingCallback expects load_best_model_at_end and
        # metric_for_best_model in training_args — confirm in models_config.json.
        callbacks=[TensorBoardCallback(), EarlyStoppingCallback(early_stopping_patience=3)]
    )
    torch.cuda.empty_cache()  # release cached GPU memory before training
    trainer.train()  # pass resume_from_checkpoint to continue a previous training run
    # trainer.save_model(output_dir)  # Save the model locally.
    os.makedirs(output_dir, exist_ok=True)
    trainer.push_to_hub(token=token)  # Push the model to the user's profile
    # NOTE(review): nothing in this function writes into output_dir (save_model
    # above is commented out), so this uploads whatever the directory already
    # holds — presumably output_dir equals training_args.output_dir; verify.
    upload_folder(repo_id=output_dir, folder_path=output_dir, token=token)  # upload to the organization and local
def load_config(model_name, config_path=None):
    """Read the entry *model_name* from the JSON models config.

    The file is read as UTF-8, the JSON standard encoding, rather than the
    platform's locale default, which breaks on non-ASCII content on e.g. Windows.

    Args:
        model_name: Top-level key of the config file.
        config_path: JSON file to read; defaults to the module-level ``config_file``.

    Returns:
        That model's config dict, with its ``"training_args"`` entry replaced by
        a ``TrainingArguments`` instance.

    Raises:
        KeyError: if *model_name* or its ``"training_args"`` entry is missing.
    """
    path = config_file if config_path is None else config_path
    with open(path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    model_config = config[model_name]
    model_config["training_args"] = TrainingArguments(**model_config["training_args"])
    return model_config
if __name__ == "__main__":
    # Command-line entry point: pick which model to train and run main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--n", choices=["mon", "class"],
        required=True, help="Elegir qué modelo entrenar"
    )
    cli_args = cli.parse_args()
    model_config = load_config(cli_args.n)
    # filter_white_noise is True only for the "class" model; argparse restricts
    # --n to exactly these two keys.
    filter_white_noise = {"mon": False, "class": True}[cli_args.n]
    main(
        model_config["training_args"],
        model_config["output_dir"],
        model_config["dataset_path"],
        filter_white_noise,
    )