Spaces: Runtime error

Pavankalyan committed · ea630e0 · 1 parent: 9fb7a19

Upload dataset.py with huggingface_hub

Files changed: dataset.py +92 -0
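The commit message indicates the file was pushed with the huggingface_hub client. A minimal sketch of what such an upload typically looks like; the target Space id is not shown on this page, so the repo_id below is a placeholder:

    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_file(
        path_or_fileobj="dataset.py",            # local file to push
        path_in_repo="dataset.py",               # destination path in the repo
        repo_id="Pavankalyan/speech-to-intent",  # placeholder Space id (assumption)
        repo_type="space",
        commit_message="Upload dataset.py with huggingface_hub",
    )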
dataset.py (ADDED)
@@ -0,0 +1,92 @@
import os

import pandas as pd

import torch
import torchaudio
import torch.nn.utils.rnn as rnn_utils

import whisper


def collate_fn(batch):
    # Pad variable-length raw waveforms to the longest sequence in the batch.
    seq, label = zip(*batch)
    seql = [x.reshape(-1) for x in seq]
    data = rnn_utils.pad_sequence(seql, batch_first=True, padding_value=0)
    label = torch.tensor(list(label))
    return data, label


def collate_mel_fn(batch):
    # Log-mel spectrograms share a fixed shape (80 mel bins after
    # whisper.pad_or_trim), so they can be stacked directly.
    seq, label = zip(*batch)
    data = torch.stack([x.reshape(80, -1) for x in seq])
    label = torch.tensor(list(label))
    return data, label


class S2IDataset(torch.utils.data.Dataset):
    """Returns (raw 16 kHz waveform, intent class) pairs."""

    def __init__(self, csv_path=None, wav_dir_path=None):
        self.df = pd.read_csv(csv_path)
        self.wav_dir = wav_dir_path
        # Source audio is 8 kHz; upsample to 16 kHz.
        self.resample = torchaudio.transforms.Resample(8000, 16000)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.df.iloc[idx]
        intent_class = row["intent_class"]
        wav_path = os.path.join(self.wav_dir, row["audio_path"])
        speaker_id = row["speaker_id"]  # unused here, kept for reference
        template = row["template"]      # unused here, kept for reference

        wav_tensor, _ = torchaudio.load(wav_path)
        wav_tensor = self.resample(wav_tensor)
        intent_class = int(intent_class)
        return wav_tensor, intent_class


class S2IMELDataset(torch.utils.data.Dataset):
    """Returns (Whisper log-mel spectrogram, intent class) pairs."""

    def __init__(self, csv_path=None, wav_dir_path=None):
        self.df = pd.read_csv(csv_path)
        self.wav_dir = wav_dir_path
        self.resample = torchaudio.transforms.Resample(8000, 16000)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.df.iloc[idx]
        intent_class = row["intent_class"]
        wav_path = os.path.join(self.wav_dir, row["audio_path"])
        speaker_id = row["speaker_id"]  # unused here, kept for reference
        template = row["template"]      # unused here, kept for reference

        wav_tensor, _ = torchaudio.load(wav_path)
        wav_tensor = self.resample(wav_tensor)

        # Whisper expects exactly 30 s of 16 kHz audio; pad/trim, then
        # compute the 80-bin log-mel spectrogram.
        wav_tensor = whisper.pad_or_trim(wav_tensor.flatten())
        mel = whisper.log_mel_spectrogram(wav_tensor)

        intent_class = int(intent_class)
        return mel, intent_class


if __name__ == "__main__":
    dataset = S2IMELDataset(
        csv_path="/root/Speech2Intent/dataset/speech-to-intent/train.csv",
        wav_dir_path="/root/Speech2Intent/dataset/speech-to-intent/")
    wav_tensor, intent_class = dataset[0]
    print(wav_tensor.shape, intent_class)

    trainloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=3,
        shuffle=True,
        num_workers=4,
        collate_fn=collate_mel_fn,
    )
    x, y = next(iter(trainloader))
    print(x.shape)
    print(y.shape)
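The __main__ block above only exercises the mel variant. S2IDataset returns variable-length raw waveforms and therefore pairs with collate_fn, which pads each batch to its longest sequence. A short sketch of that combination, reusing the same dataset paths as the demo above:

    raw_dataset = S2IDataset(
        csv_path="/root/Speech2Intent/dataset/speech-to-intent/train.csv",
        wav_dir_path="/root/Speech2Intent/dataset/speech-to-intent/",
    )
    raw_loader = torch.utils.data.DataLoader(
        raw_dataset,
        batch_size=3,
        shuffle=True,
        collate_fn=collate_fn,  # pads waveforms to the batch max length
    )
    wavs, labels = next(iter(raw_loader))
    print(wavs.shape, labels.shape)  # (3, max_len) and (3,)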