Spaces:
Running
on
Zero
Running
on
Zero
wli3221134
commited on
Update dataset.py
Browse files- dataset.py +11 -9
dataset.py
CHANGED
@@ -6,21 +6,17 @@ import librosa
|
|
6 |
import numpy as np
|
7 |
|
8 |
class DemoDataset(Dataset):
|
9 |
-
def __init__(self, demonstration_paths, query_path, sample_rate=16000):
|
10 |
self.sample_rate = sample_rate
|
11 |
self.query_path = query_path
|
12 |
|
13 |
# Convert to list if single path
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
self.demonstration_paths = demonstration_paths
|
18 |
-
|
19 |
# Load feature extractor
|
20 |
-
self.feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/
|
21 |
|
22 |
-
print(f'Number of demonstration audios: {len(self.demonstration_paths)}')
|
23 |
-
print(f'Query audio: {self.query_path}')
|
24 |
|
25 |
def load_pad(self, path, max_length=64000):
|
26 |
"""Load and pad audio file"""
|
@@ -74,9 +70,12 @@ class DemoDataset(Dataset):
|
|
74 |
)
|
75 |
prompt_features.append(prompt_feature)
|
76 |
|
|
|
|
|
77 |
return {
|
78 |
'main_features': main_features,
|
79 |
'prompt_features': prompt_features,
|
|
|
80 |
'file_name': os.path.basename(self.query_path),
|
81 |
'file_path': self.query_path
|
82 |
}
|
@@ -114,9 +113,12 @@ def collate_fn(batch):
|
|
114 |
file_names = [item['file_name'] for item in batch]
|
115 |
file_paths = [item['file_path'] for item in batch]
|
116 |
|
|
|
|
|
117 |
return {
|
118 |
'main_features': main_features,
|
119 |
'prompt_features': prompt_features,
|
|
|
120 |
'file_names': file_names,
|
121 |
'file_paths': file_paths
|
122 |
}
|
|
|
6 |
import numpy as np
|
7 |
|
8 |
class DemoDataset(Dataset):
|
9 |
+
def __init__(self, demonstration_paths, demonstration_labels, query_path, sample_rate=16000):
|
10 |
self.sample_rate = sample_rate
|
11 |
self.query_path = query_path
|
12 |
|
13 |
# Convert to list if single path
|
14 |
+
self.demonstration_paths = demonstration_paths
|
15 |
+
self.demonstration_labels = [0 if label == 'bonafide' else 1 for label in demonstration_labels]
|
16 |
+
|
|
|
|
|
17 |
# Load feature extractor
|
18 |
+
self.feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
|
19 |
|
|
|
|
|
20 |
|
21 |
def load_pad(self, path, max_length=64000):
|
22 |
"""Load and pad audio file"""
|
|
|
70 |
)
|
71 |
prompt_features.append(prompt_feature)
|
72 |
|
73 |
+
prompt_labels = torch.tensor(self.demonstration_labels, dtype=torch.long)
|
74 |
+
|
75 |
return {
|
76 |
'main_features': main_features,
|
77 |
'prompt_features': prompt_features,
|
78 |
+
'prompt_labels': prompt_labels,
|
79 |
'file_name': os.path.basename(self.query_path),
|
80 |
'file_path': self.query_path
|
81 |
}
|
|
|
113 |
file_names = [item['file_name'] for item in batch]
|
114 |
file_paths = [item['file_path'] for item in batch]
|
115 |
|
116 |
+
prompt_labels = torch.tensor([item['prompt_labels'] for item in batch], dtype=torch.long)
|
117 |
+
|
118 |
return {
|
119 |
'main_features': main_features,
|
120 |
'prompt_features': prompt_features,
|
121 |
+
'prompt_labels': prompt_labels,
|
122 |
'file_names': file_names,
|
123 |
'file_paths': file_paths
|
124 |
}
|