Marcos12886 committed on
Commit
81672a0
·
verified ·
1 Parent(s): 773b97e

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +20 -8
model.py CHANGED
@@ -6,6 +6,7 @@ import torchaudio
6
  from torch.utils.data import Dataset, DataLoader
7
  from huggingface_hub import login, upload_folder
8
  from transformers.integrations import TensorBoardCallback
 
9
  from transformers import (
10
  Wav2Vec2FeatureExtractor, HubertConfig, HubertForSequenceClassification,
11
  Trainer, TrainingArguments,
@@ -17,11 +18,11 @@ FEATURE_EXTRACTOR = Wav2Vec2FeatureExtractor.from_pretrained(MODEL)
17
  seed = 123
18
  MAX_DURATION = 1.00
19
  SAMPLING_RATE = FEATURE_EXTRACTOR.sampling_rate # 16000
20
- token = os.getenv('MODEL_REPO_ID')
21
  config_file = "models_config.json"
22
  clasificador = "class"
23
  monitor = "mon"
24
- batch_size = 4096
25
 
26
  class AudioDataset(Dataset):
27
  def __init__(self, dataset_path, label2id):
@@ -60,13 +61,19 @@ class AudioDataset(Dataset):
60
  waveform = resampler(waveform)
61
  if waveform.shape[0] > 1: # Si es stereo, convertir a mono
62
  waveform = waveform.mean(dim=0)
63
- waveform = waveform / torch.max(torch.abs(waveform))
 
 
 
 
 
 
64
  inputs = FEATURE_EXTRACTOR(
65
  waveform,
66
  sampling_rate=SAMPLING_RATE,
67
  return_tensors="pt",
68
- max_length=int(SAMPLING_RATE * MAX_DURATION),
69
- truncation=True,
70
  padding=True,
71
  )
72
  return inputs.input_values.squeeze()
@@ -131,10 +138,15 @@ def model_params(dataset_path):
131
  return model, train_dataloader, test_dataloader, id2label
132
 
133
  def compute_metrics(eval_pred):
134
- predictions = torch.argmax(input=eval_pred.predictions)
135
- references = eval_pred.label_ids
 
 
136
  return {
137
- "accuracy": torch.mean(predictions == references),
 
 
 
138
  }
139
 
140
  def main(training_args, output_dir, dataset_path):
 
6
  from torch.utils.data import Dataset, DataLoader
7
  from huggingface_hub import login, upload_folder
8
  from transformers.integrations import TensorBoardCallback
9
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
10
  from transformers import (
11
  Wav2Vec2FeatureExtractor, HubertConfig, HubertForSequenceClassification,
12
  Trainer, TrainingArguments,
 
18
  seed = 123
19
  MAX_DURATION = 1.00
20
  SAMPLING_RATE = FEATURE_EXTRACTOR.sampling_rate # 16000
21
+ token = os.getenv("HF_TOKEN")
22
  config_file = "models_config.json"
23
  clasificador = "class"
24
  monitor = "mon"
25
+ batch_size = 16
26
 
27
  class AudioDataset(Dataset):
28
  def __init__(self, dataset_path, label2id):
 
61
  waveform = resampler(waveform)
62
  if waveform.shape[0] > 1: # Si es stereo, convertir a mono
63
  waveform = waveform.mean(dim=0)
64
+ waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-6) # Sin 1e-6 el accuracy es pésimo!!
65
+ max_length = int(SAMPLING_RATE * MAX_DURATION)
66
+ if waveform.shape[0] > max_length:
67
+ waveform = waveform[:max_length]
68
+ else:
69
+ # Pad the waveform if it's shorter than max length
70
+ waveform = torch.nn.functional.pad(waveform, (0, max_length - waveform.shape[0]))
71
  inputs = FEATURE_EXTRACTOR(
72
  waveform,
73
  sampling_rate=SAMPLING_RATE,
74
  return_tensors="pt",
75
+ # max_length=int(SAMPLING_RATE * MAX_DURATION),
76
+ # truncation=True,
77
  padding=True,
78
  )
79
  return inputs.input_values.squeeze()
 
138
  return model, train_dataloader, test_dataloader, id2label
139
 
140
def compute_metrics(eval_pred):
    """Compute evaluation metrics for the Trainer.

    Args:
        eval_pred: an EvalPrediction-like object with `.predictions`
            (per-class logits/scores) and `.label_ids` (gold labels).
            NOTE(review): assumes predictions are a single array, not a
            tuple of outputs — confirm against the model's forward().

    Returns:
        dict with "accuracy", "precision", "recall", "f1"; the latter
        three are weighted averages across classes.
    """
    # Collapse per-class scores to predicted class ids along the last axis.
    logits = torch.tensor(eval_pred.predictions)
    y_pred = torch.argmax(logits, dim=-1)
    y_true = torch.tensor(eval_pred.label_ids)

    # sklearn accepts torch tensors here; weighted averaging accounts
    # for class imbalance.
    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1_score, _ = precision_recall_fscore_support(
        y_true, y_pred, average='weighted'
    )

    return {
        "accuracy": acc,
        "precision": prec,
        "recall": rec,
        "f1": f1_score,
    }
151
 
152
  def main(training_args, output_dir, dataset_path):