Commit
·
c0fab16
1
Parent(s):
86366c6
Update README.md (#9)
Browse files — Update README.md (fc858efea9fac4aed1b0122c9e039b65a34c31b8)
Co-authored-by: Salman Zafar <[email protected]>
README.md
CHANGED
@@ -78,32 +78,37 @@ class CustomDataset(torch.utils.data.Dataset):
|
|
78 |
return audio
|
79 |
|
80 |
|
81 |
-
def __getitem__(self, index)
|
82 |
-
"""
|
83 |
-
Return the audio and the sampling rate
|
84 |
-
"""
|
85 |
if self.basedir is None:
|
86 |
filepath = self.dataset[index]
|
87 |
else:
|
88 |
filepath = os.path.join(self.basedir, self.dataset[index])
|
89 |
-
|
90 |
speech_array, sr = torchaudio.load(filepath)
|
91 |
-
|
92 |
-
# Transform to mono
|
93 |
if speech_array.shape[0] > 1:
|
94 |
speech_array = torch.mean(speech_array, dim=0, keepdim=True)
|
95 |
-
|
96 |
if sr != self.sampling_rate:
|
97 |
transform = torchaudio.transforms.Resample(sr, self.sampling_rate)
|
98 |
speech_array = transform(speech_array)
|
99 |
sr = self.sampling_rate
|
100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
speech_array = speech_array.squeeze().numpy()
|
|
|
|
|
102 |
|
103 |
-
# Cut or pad audio
|
104 |
-
speech_array = self._cutorpad(speech_array)
|
105 |
-
|
106 |
-
return speech_array
|
107 |
|
108 |
class CollateFunc:
|
109 |
def __init__(
|
@@ -172,7 +177,8 @@ def get_gender(model_name_or_path: str, audio_paths: List[str], label2id: Dict,
|
|
172 |
id2label=id2label,
|
173 |
)
|
174 |
|
175 |
-
test_dataset = CustomDataset(audio_paths)
|
|
|
176 |
data_collator = CollateFunc(
|
177 |
processor=feature_extractor,
|
178 |
padding=True,
|
|
|
78 |
return audio
|
79 |
|
80 |
|
81 |
+
def __getitem__(self, index):
|
|
|
|
|
|
|
82 |
if self.basedir is None:
|
83 |
filepath = self.dataset[index]
|
84 |
else:
|
85 |
filepath = os.path.join(self.basedir, self.dataset[index])
|
86 |
+
|
87 |
speech_array, sr = torchaudio.load(filepath)
|
88 |
+
|
|
|
89 |
if speech_array.shape[0] > 1:
|
90 |
speech_array = torch.mean(speech_array, dim=0, keepdim=True)
|
91 |
+
|
92 |
if sr != self.sampling_rate:
|
93 |
transform = torchaudio.transforms.Resample(sr, self.sampling_rate)
|
94 |
speech_array = transform(speech_array)
|
95 |
sr = self.sampling_rate
|
96 |
+
|
97 |
+
len_audio = speech_array.shape[1]
|
98 |
+
|
99 |
+
# Pad or truncate the audio to match the desired length
|
100 |
+
if len_audio < self.max_audio_len * self.sampling_rate:
|
101 |
+
# Pad the audio if it's shorter than the desired length
|
102 |
+
padding = torch.zeros(1, self.max_audio_len * self.sampling_rate - len_audio)
|
103 |
+
speech_array = torch.cat([speech_array, padding], dim=1)
|
104 |
+
else:
|
105 |
+
# Truncate the audio if it's longer than the desired length
|
106 |
+
speech_array = speech_array[:, :self.max_audio_len * self.sampling_rate]
|
107 |
+
|
108 |
speech_array = speech_array.squeeze().numpy()
|
109 |
+
|
110 |
+
return {"input_values": speech_array, "attention_mask": None}
|
111 |
|
|
|
|
|
|
|
|
|
112 |
|
113 |
class CollateFunc:
|
114 |
def __init__(
|
|
|
177 |
id2label=id2label,
|
178 |
)
|
179 |
|
180 |
+
test_dataset = CustomDataset(audio_paths, max_audio_len=300) # for 5-minute audio
|
181 |
+
|
182 |
data_collator = CollateFunc(
|
183 |
processor=feature_extractor,
|
184 |
padding=True,
|