alefiury Salman11223 committed on
Commit
c0fab16
·
1 Parent(s): 86366c6

Update README.md (#9)

Browse files

- Update README.md (fc858efea9fac4aed1b0122c9e039b65a34c31b8)


Co-authored-by: Salman Zafar <[email protected]>

Files changed (1) hide show
  1. README.md +20 -14
README.md CHANGED
@@ -78,32 +78,37 @@ class CustomDataset(torch.utils.data.Dataset):
78
  return audio
79
 
80
 
81
- def __getitem__(self, index) -> torch.Tensor:
82
- """
83
- Return the audio and the sampling rate
84
- """
85
  if self.basedir is None:
86
  filepath = self.dataset[index]
87
  else:
88
  filepath = os.path.join(self.basedir, self.dataset[index])
89
-
90
  speech_array, sr = torchaudio.load(filepath)
91
-
92
- # Transform to mono
93
  if speech_array.shape[0] > 1:
94
  speech_array = torch.mean(speech_array, dim=0, keepdim=True)
95
-
96
  if sr != self.sampling_rate:
97
  transform = torchaudio.transforms.Resample(sr, self.sampling_rate)
98
  speech_array = transform(speech_array)
99
  sr = self.sampling_rate
100
-
 
 
 
 
 
 
 
 
 
 
 
101
  speech_array = speech_array.squeeze().numpy()
 
 
102
 
103
- # Cut or pad audio
104
- speech_array = self._cutorpad(speech_array)
105
-
106
- return speech_array
107
 
108
  class CollateFunc:
109
  def __init__(
@@ -172,7 +177,8 @@ def get_gender(model_name_or_path: str, audio_paths: List[str], label2id: Dict,
172
  id2label=id2label,
173
  )
174
 
175
- test_dataset = CustomDataset(audio_paths)
 
176
  data_collator = CollateFunc(
177
  processor=feature_extractor,
178
  padding=True,
 
78
  return audio
79
 
80
 
81
+ def __getitem__(self, index):
 
 
 
82
  if self.basedir is None:
83
  filepath = self.dataset[index]
84
  else:
85
  filepath = os.path.join(self.basedir, self.dataset[index])
86
+
87
  speech_array, sr = torchaudio.load(filepath)
88
+
 
89
  if speech_array.shape[0] > 1:
90
  speech_array = torch.mean(speech_array, dim=0, keepdim=True)
91
+
92
  if sr != self.sampling_rate:
93
  transform = torchaudio.transforms.Resample(sr, self.sampling_rate)
94
  speech_array = transform(speech_array)
95
  sr = self.sampling_rate
96
+
97
+ len_audio = speech_array.shape[1]
98
+
99
+ # Pad or truncate the audio to match the desired length
100
+ if len_audio < self.max_audio_len * self.sampling_rate:
101
+ # Pad the audio if it's shorter than the desired length
102
+ padding = torch.zeros(1, self.max_audio_len * self.sampling_rate - len_audio)
103
+ speech_array = torch.cat([speech_array, padding], dim=1)
104
+ else:
105
+ # Truncate the audio if it's longer than the desired length
106
+ speech_array = speech_array[:, :self.max_audio_len * self.sampling_rate]
107
+
108
  speech_array = speech_array.squeeze().numpy()
109
+
110
+ return {"input_values": speech_array, "attention_mask": None}
111
 
 
 
 
 
112
 
113
  class CollateFunc:
114
  def __init__(
 
177
  id2label=id2label,
178
  )
179
 
180
+ test_dataset = CustomDataset(audio_paths, max_audio_len=300) # for 5-minute audio
181
+
182
  data_collator = CollateFunc(
183
  processor=feature_extractor,
184
  padding=True,