Files changed (1)
  1. README.md +20 -14

README.md
@@ -77,32 +77,37 @@ class CustomDataset(torch.utils.data.Dataset):
         return audio
 
 
-    def __getitem__(self, index) -> torch.Tensor:
-        """
-        Return the audio and the sampling rate
-        """
+    def __getitem__(self, index):
         if self.basedir is None:
             filepath = self.dataset[index]
         else:
             filepath = os.path.join(self.basedir, self.dataset[index])
-
+
         speech_array, sr = torchaudio.load(filepath)
-
-        # Transform to mono
+
         if speech_array.shape[0] > 1:
             speech_array = torch.mean(speech_array, dim=0, keepdim=True)
-
+
         if sr != self.sampling_rate:
             transform = torchaudio.transforms.Resample(sr, self.sampling_rate)
             speech_array = transform(speech_array)
             sr = self.sampling_rate
-
+
+        len_audio = speech_array.shape[1]
+
+        # Pad or truncate the audio to match the desired length
+        if len_audio < self.max_audio_len * self.sampling_rate:
+            # Pad the audio if it's shorter than the desired length
+            padding = torch.zeros(1, self.max_audio_len * self.sampling_rate - len_audio)
+            speech_array = torch.cat([speech_array, padding], dim=1)
+        else:
+            # Truncate the audio if it's longer than the desired length
+            speech_array = speech_array[:, :self.max_audio_len * self.sampling_rate]
+
         speech_array = speech_array.squeeze().numpy()
+
+        return {"input_values": speech_array, "attention_mask": None}
 
-        # Cut or pad audio
-        speech_array = self._cutorpad(speech_array)
-
-        return speech_array
 
 class CollateFunc:
     def __init__(
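The `_cutorpad` helper is replaced here by inline pad/truncate logic. As a quick sanity check for reviewers, the same logic can be exercised standalone; this is a minimal sketch (the `pad_or_truncate` name and the 16 kHz / 5 s defaults are hypothetical, not part of the PR), assuming a mono `[1, T]` waveform as produced by the code above:

    import torch

    def pad_or_truncate(speech_array, sampling_rate=16000, max_audio_len=5):
        # Fixed target length in samples, as in the new __getitem__
        target_len = max_audio_len * sampling_rate
        len_audio = speech_array.shape[1]
        if len_audio < target_len:
            # Zero-pad short clips up to the target length
            padding = torch.zeros(1, target_len - len_audio)
            speech_array = torch.cat([speech_array, padding], dim=1)
        else:
            # Keep only the first target_len samples of long clips
            speech_array = speech_array[:, :target_len]
        return speech_array

    assert pad_or_truncate(torch.randn(1, 8000)).shape == (1, 80000)    # padded
    assert pad_or_truncate(torch.randn(1, 160000)).shape == (1, 80000)  # truncated

Note that every item now comes out at a fixed length, and `__getitem__` returns `"attention_mask": None`, so distinguishing real samples from zero padding is presumably left to `CollateFunc`.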
@@ -171,7 +176,8 @@ def get_gender(model_name_or_path: str, audio_paths: List[str], label2id: Dict,
         id2label=id2label,
     )
 
-    test_dataset = CustomDataset(audio_paths)
+    test_dataset = CustomDataset(audio_paths, max_audio_len=300)  # for 5-minute audio
+
     data_collator = CollateFunc(
         processor=feature_extractor,
         padding=True,
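For context, a hypothetical end-to-end wiring of the changed pieces (`audio_paths` and `data_collator` as in the surrounding code; the `batch_size` is arbitrary) could look like the following. If the sampling rate is 16 kHz, `max_audio_len=300` fixes every item at 300 × 16000 = 4,800,000 samples (about 19 MB per clip as float32), so small batch sizes are advisable:

    from torch.utils.data import DataLoader

    # Dataset with the new fixed-length behaviour, collated as in get_gender
    test_dataset = CustomDataset(audio_paths, max_audio_len=300)
    test_loader = DataLoader(
        test_dataset,
        batch_size=2,          # kept small: each item is 5 minutes of audio
        collate_fn=data_collator,
        shuffle=False,
    )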
 