Pavankalyan committed
Commit ea630e0 · 1 Parent(s): 9fb7a19

Upload dataset.py with huggingface_hub

Files changed (1):
  1. dataset.py +92 -0
dataset.py ADDED
@@ -0,0 +1,92 @@
import os
import pandas as pd

import torch
import torchaudio
import torch.nn.utils.rnn as rnn_utils

import whisper

def collate_fn(batch):
    # Zero-pad variable-length waveforms to the longest clip in the batch.
    (seq, label) = zip(*batch)
    seql = [x.reshape(-1,) for x in seq]
    data = rnn_utils.pad_sequence(seql, batch_first=True, padding_value=0)
    label = torch.tensor(list(label))
    return data, label

def collate_mel_fn(batch):
    # Stack fixed-size (80, n_frames) log-Mel spectrograms into one tensor.
    (seq, label) = zip(*batch)
    data = torch.stack([x.reshape(80, -1) for x in seq])
    label = torch.tensor(list(label))
    return data, label

class S2IDataset(torch.utils.data.Dataset):
    def __init__(self, csv_path=None, wav_dir_path=None):
        self.df = pd.read_csv(csv_path)
        self.wav_dir = wav_dir_path
        # Source audio is 8 kHz; upsample to the 16 kHz Whisper expects.
        self.resample = torchaudio.transforms.Resample(8000, 16000)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.df.iloc[idx]
        intent_class = row["intent_class"]
        wav_path = os.path.join(self.wav_dir, row["audio_path"])
        speaker_id = row["speaker_id"]  # read but unused; kept for reference
        template = row["template"]      # read but unused; kept for reference

        wav_tensor, _ = torchaudio.load(wav_path)
        wav_tensor = self.resample(wav_tensor)
        intent_class = int(intent_class)
        return wav_tensor, intent_class

class S2IMELDataset(torch.utils.data.Dataset):
    def __init__(self, csv_path=None, wav_dir_path=None):
        self.df = pd.read_csv(csv_path)
        self.wav_dir = wav_dir_path
        # Source audio is 8 kHz; upsample to the 16 kHz Whisper expects.
        self.resample = torchaudio.transforms.Resample(8000, 16000)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.df.iloc[idx]
        intent_class = row["intent_class"]
        wav_path = os.path.join(self.wav_dir, row["audio_path"])
        speaker_id = row["speaker_id"]  # read but unused; kept for reference
        template = row["template"]      # read but unused; kept for reference

        wav_tensor, _ = torchaudio.load(wav_path)
        wav_tensor = self.resample(wav_tensor)

        # Whisper operates on fixed 30 s windows: pad/trim the waveform,
        # then compute an 80-bin log-Mel spectrogram.
        wav_tensor = whisper.pad_or_trim(wav_tensor.flatten())
        mel = whisper.log_mel_spectrogram(wav_tensor)

        intent_class = int(intent_class)
        return mel, intent_class

if __name__ == "__main__":
    # Note: __init__ takes no sr argument (the 8 kHz -> 16 kHz resampling is
    # hard-wired above), so none is passed here.
    dataset = S2IMELDataset(
        csv_path="/root/Speech2Intent/dataset/speech-to-intent/train.csv",
        wav_dir_path="/root/Speech2Intent/dataset/speech-to-intent/")
    mel, intent_class = dataset[0]
    print(mel.shape, intent_class)

    trainloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=3,
        shuffle=True,
        num_workers=4,
        collate_fn=collate_mel_fn,
    )
    x, y = next(iter(trainloader))
    print(x.shape)
    print(y.shape)
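
For reference, a minimal self-contained sketch of what collate_fn produces: three mono waveforms of different lengths are flattened and zero-padded to the longest clip. The clip lengths and labels here are invented for illustration, and dataset.py is assumed to be importable.

import torch
from dataset import collate_fn

# Fake (waveform, label) pairs, shaped like torchaudio.load output (1, n_samples).
batch = [
    (torch.randn(1, 12000), 0),
    (torch.randn(1, 16000), 3),
    (torch.randn(1, 8000), 1),
]
data, labels = collate_fn(batch)
print(data.shape)  # torch.Size([3, 16000]) -- padded to the longest clip
print(labels)      # tensor([0, 3, 1])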
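
And a sketch of the raw-waveform path, mirroring the __main__ block above but with S2IDataset and collate_fn instead of the mel variants; it assumes the same hard-coded CSV and wav directory exist.

import torch
from dataset import S2IDataset, collate_fn

dataset = S2IDataset(
    csv_path="/root/Speech2Intent/dataset/speech-to-intent/train.csv",
    wav_dir_path="/root/Speech2Intent/dataset/speech-to-intent/")

trainloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=3,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)
x, y = next(iter(trainloader))
print(x.shape)  # (3, longest_clip_in_batch) after zero-padding
print(y.shape)  # (3,)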