Marcos12886 committed: Update model.py

model.py CHANGED
@@ -6,6 +6,7 @@ import torchaudio
 from torch.utils.data import Dataset, DataLoader
 from huggingface_hub import login, upload_folder
 from transformers.integrations import TensorBoardCallback
+from sklearn.metrics import accuracy_score, precision_recall_fscore_support
 from transformers import (
     Wav2Vec2FeatureExtractor, HubertConfig, HubertForSequenceClassification,
     Trainer, TrainingArguments,
@@ -17,11 +18,11 @@ FEATURE_EXTRACTOR = Wav2Vec2FeatureExtractor.from_pretrained(MODEL)
 seed = 123
 MAX_DURATION = 1.00
 SAMPLING_RATE = FEATURE_EXTRACTOR.sampling_rate # 16000
-token = os.getenv(
+token = os.getenv("HF_TOKEN")
 config_file = "models_config.json"
 clasificador = "class"
 monitor = "mon"
-batch_size =
+batch_size = 16
 
 class AudioDataset(Dataset):
     def __init__(self, dataset_path, label2id):
@@ -60,13 +61,19 @@ class AudioDataset(Dataset):
         waveform = resampler(waveform)
         if waveform.shape[0] > 1: # If stereo, convert to mono
             waveform = waveform.mean(dim=0)
-        waveform = waveform / torch.max(torch.abs(waveform))
+        waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-6) # Without the 1e-6 the accuracy is terrible!!
+        max_length = int(SAMPLING_RATE * MAX_DURATION)
+        if waveform.shape[0] > max_length:
+            waveform = waveform[:max_length]
+        else:
+            # Pad the waveform if it's shorter than max length
+            waveform = torch.nn.functional.pad(waveform, (0, max_length - waveform.shape[0]))
         inputs = FEATURE_EXTRACTOR(
             waveform,
             sampling_rate=SAMPLING_RATE,
             return_tensors="pt",
-            max_length=int(SAMPLING_RATE * MAX_DURATION),
-            truncation=True,
+            # max_length=int(SAMPLING_RATE * MAX_DURATION),
+            # truncation=True,
             padding=True,
         )
         return inputs.input_values.squeeze()
@@ -131,10 +138,15 @@ def model_params(dataset_path):
     return model, train_dataloader, test_dataloader, id2label
 
 def compute_metrics(eval_pred):
-    predictions = torch.argmax(
-    references = eval_pred.label_ids
+    predictions = torch.argmax(torch.tensor(eval_pred.predictions), dim=-1)
+    references = torch.tensor(eval_pred.label_ids)
+    accuracy = accuracy_score(references, predictions)
+    precision, recall, f1, _ = precision_recall_fscore_support(references, predictions, average='weighted')
     return {
-        "accuracy":
+        "accuracy": accuracy,
+        "precision": precision,
+        "recall": recall,
+        "f1": f1,
     }
 
 def main(training_args, output_dir, dataset_path):
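
For reference, here is the new fixed-length preprocessing from the AudioDataset hunk pulled out as a standalone sketch. It assumes the 16 kHz rate noted next to SAMPLING_RATE and the 1-second MAX_DURATION above; the preprocess helper and the dummy tensor are illustrative only, not part of model.py.

# Standalone sketch of the truncate/pad preprocessing introduced in this commit.
# Assumption: input is a (channels, samples) tensor as returned by torchaudio.load,
# already resampled to 16 kHz. `preprocess` and the dummy data are illustrative.
import torch

SAMPLING_RATE = 16000
MAX_DURATION = 1.00

def preprocess(waveform: torch.Tensor) -> torch.Tensor:
    if waveform.shape[0] > 1:          # if stereo, convert to mono
        waveform = waveform.mean(dim=0)
    else:                              # mono: just drop the channel dimension
        waveform = waveform.squeeze(0)
    # Peak-normalize; the 1e-6 avoids division by zero on silent clips.
    waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-6)
    # Truncate or zero-pad to exactly MAX_DURATION seconds.
    max_length = int(SAMPLING_RATE * MAX_DURATION)
    if waveform.shape[0] > max_length:
        waveform = waveform[:max_length]
    else:
        waveform = torch.nn.functional.pad(waveform, (0, max_length - waveform.shape[0]))
    return waveform

dummy = torch.randn(2, 20000)    # 1.25 s of fake stereo audio
print(preprocess(dummy).shape)   # torch.Size([16000])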
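
The same hunk comments out the extractor-side max_length and truncation arguments, since the clip length is now fixed before the call. For comparison, a minimal sketch of having Wav2Vec2FeatureExtractor pad and truncate to a fixed window itself, roughly what the removed arguments were aiming at; the default-constructed extractor and fake waveform are stand-ins, not the project's checkpoint-specific FEATURE_EXTRACTOR.

# Sketch only: extractor-side padding/truncation, shown for comparison with the
# manual version above. The default-constructed extractor is a stand-in for the
# checkpoint-specific FEATURE_EXTRACTOR in model.py.
import torch
from transformers import Wav2Vec2FeatureExtractor

extractor = Wav2Vec2FeatureExtractor()  # defaults: 16 kHz, padding_value=0.0
wave = torch.randn(12000)               # 0.75 s of fake mono audio
inputs = extractor(
    wave,
    sampling_rate=16000,
    return_tensors="pt",
    padding="max_length",               # pad shorter clips up to max_length
    max_length=16000,                   # 1 second at 16 kHz
    truncation=True,                    # cut longer clips down to max_length
)
print(inputs.input_values.shape)        # torch.Size([1, 16000])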
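
The reworked compute_metrics can be sanity-checked directly on a fake EvalPrediction, the container the Trainer passes at evaluation time; the logits and labels below are made up.

# Self-contained check of the new compute_metrics on dummy data.
import numpy as np
import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import EvalPrediction

def compute_metrics(eval_pred):
    predictions = torch.argmax(torch.tensor(eval_pred.predictions), dim=-1)
    references = torch.tensor(eval_pred.label_ids)
    accuracy = accuracy_score(references, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        references, predictions, average="weighted"
    )
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

logits = np.array([[2.0, 0.1], [0.2, 1.5], [1.0, 0.9]])  # fake model outputs (3 clips, 2 classes)
labels = np.array([0, 1, 1])                             # fake ground-truth labels
print(compute_metrics(EvalPrediction(predictions=logits, label_ids=labels)))
# prints a dict with accuracy, precision, recall and f1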