Update README.md #9, opened by Salman11223
README.md CHANGED
@@ -77,32 +77,37 @@ class CustomDataset(torch.utils.data.Dataset):
         return audio
 
 
-    def __getitem__(self, index)
-        """
-        Return the audio and the sampling rate
-        """
+    def __getitem__(self, index):
         if self.basedir is None:
             filepath = self.dataset[index]
         else:
             filepath = os.path.join(self.basedir, self.dataset[index])
-
+
         speech_array, sr = torchaudio.load(filepath)
-
-        # Transform to mono
+
         if speech_array.shape[0] > 1:
             speech_array = torch.mean(speech_array, dim=0, keepdim=True)
-
+
         if sr != self.sampling_rate:
             transform = torchaudio.transforms.Resample(sr, self.sampling_rate)
             speech_array = transform(speech_array)
             sr = self.sampling_rate
-
+
+        len_audio = speech_array.shape[1]
+
+        # Pad or truncate the audio to match the desired length
+        if len_audio < self.max_audio_len * self.sampling_rate:
+            # Pad the audio if it's shorter than the desired length
+            padding = torch.zeros(1, self.max_audio_len * self.sampling_rate - len_audio)
+            speech_array = torch.cat([speech_array, padding], dim=1)
+        else:
+            # Truncate the audio if it's longer than the desired length
+            speech_array = speech_array[:, :self.max_audio_len * self.sampling_rate]
+
         speech_array = speech_array.squeeze().numpy()
+
+        return {"input_values": speech_array, "attention_mask": None}
 
-        # Cut or pad audio
-        speech_array = self._cutorpad(speech_array)
-
-        return speech_array
 
 class CollateFunc:
     def __init__(

@@ -171,7 +176,8 @@ def get_gender(model_name_or_path: str, audio_paths: List[str], label2id: Dict,
         id2label=id2label,
     )
 
-    test_dataset = CustomDataset(audio_paths)
+    test_dataset = CustomDataset(audio_paths, max_audio_len=300)  # for 5-minute audio
+
     data_collator = CollateFunc(
         processor=feature_extractor,
         padding=True,
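
A quick way to sanity-check the fixed-length behavior this change introduces: the sketch below reimplements just the pad-or-truncate step from the revised __getitem__ as a free function. It is not part of the PR; the function name pad_or_truncate and the synthetic waveforms are illustrative, and the 16 kHz default for sampling_rate is an assumption (the diff only references self.sampling_rate without showing its value).

# Minimal standalone sketch (not part of the PR) of the pad-or-truncate step
# performed by the revised __getitem__. Only torch is required; the random
# tensors stand in for the (1, num_samples) output of torchaudio.load().
import torch

def pad_or_truncate(speech_array: torch.Tensor,
                    max_audio_len: int,
                    sampling_rate: int = 16000) -> torch.Tensor:
    # sampling_rate=16000 is an assumed default; adjust to match the dataset.
    target_len = max_audio_len * sampling_rate
    len_audio = speech_array.shape[1]
    if len_audio < target_len:
        # Shorter clips are right-padded with zeros (silence)...
        padding = torch.zeros(1, target_len - len_audio)
        speech_array = torch.cat([speech_array, padding], dim=1)
    else:
        # ...and longer clips are cut off at the target length.
        speech_array = speech_array[:, :target_len]
    return speech_array

short_wav = torch.randn(1, 8000)         # 0.5 s at 16 kHz
long_wav = torch.randn(1, 400 * 16000)   # 400 s at 16 kHz
print(pad_or_truncate(short_wav, max_audio_len=300).shape)  # torch.Size([1, 4800000])
print(pad_or_truncate(long_wav, max_audio_len=300).shape)   # torch.Size([1, 4800000])

Since every item is forced to the same max_audio_len, batch padding in CollateFunc becomes effectively a no-op for same-length inputs, which is presumably why the new __getitem__ can return attention_mask as None.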