wli3221134 committed on
Commit
f96cfa2
·
verified ·
1 Parent(s): 18b7b66

Update dataset.py

Browse files
Files changed (1) hide show
  1. dataset.py +11 -9
dataset.py CHANGED
@@ -6,21 +6,17 @@ import librosa
6
  import numpy as np
7
 
8
  class DemoDataset(Dataset):
9
- def __init__(self, demonstration_paths, query_path, sample_rate=16000):
10
  self.sample_rate = sample_rate
11
  self.query_path = query_path
12
 
13
  # Convert to list if single path
14
- if isinstance(demonstration_paths, str):
15
- self.demonstration_paths = [demonstration_paths]
16
- else:
17
- self.demonstration_paths = demonstration_paths
18
-
19
  # Load feature extractor
20
- self.feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
21
 
22
- print(f'Number of demonstration audios: {len(self.demonstration_paths)}')
23
- print(f'Query audio: {self.query_path}')
24
 
25
  def load_pad(self, path, max_length=64000):
26
  """Load and pad audio file"""
@@ -74,9 +70,12 @@ class DemoDataset(Dataset):
74
  )
75
  prompt_features.append(prompt_feature)
76
 
 
 
77
  return {
78
  'main_features': main_features,
79
  'prompt_features': prompt_features,
 
80
  'file_name': os.path.basename(self.query_path),
81
  'file_path': self.query_path
82
  }
@@ -114,9 +113,12 @@ def collate_fn(batch):
114
  file_names = [item['file_name'] for item in batch]
115
  file_paths = [item['file_path'] for item in batch]
116
 
 
 
117
  return {
118
  'main_features': main_features,
119
  'prompt_features': prompt_features,
 
120
  'file_names': file_names,
121
  'file_paths': file_paths
122
  }
 
6
  import numpy as np
7
 
8
  class DemoDataset(Dataset):
9
+ def __init__(self, demonstration_paths, demonstration_labels, query_path, sample_rate=16000):
10
  self.sample_rate = sample_rate
11
  self.query_path = query_path
12
 
13
  # Convert to list if single path
14
+ self.demonstration_paths = demonstration_paths
15
+ self.demonstration_labels = [0 if label == 'bonafide' else 1 for label in demonstration_labels]
16
+
 
 
17
  # Load feature extractor
18
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
19
 
 
 
20
 
21
  def load_pad(self, path, max_length=64000):
22
  """Load and pad audio file"""
 
70
  )
71
  prompt_features.append(prompt_feature)
72
 
73
+ prompt_labels = torch.tensor(self.demonstration_labels, dtype=torch.long)
74
+
75
  return {
76
  'main_features': main_features,
77
  'prompt_features': prompt_features,
78
+ 'prompt_labels': prompt_labels,
79
  'file_name': os.path.basename(self.query_path),
80
  'file_path': self.query_path
81
  }
 
113
  file_names = [item['file_name'] for item in batch]
114
  file_paths = [item['file_path'] for item in batch]
115
 
116
+ prompt_labels = torch.tensor([item['prompt_labels'] for item in batch], dtype=torch.long)
117
+
118
  return {
119
  'main_features': main_features,
120
  'prompt_features': prompt_features,
121
+ 'prompt_labels': prompt_labels,
122
  'file_names': file_names,
123
  'file_paths': file_paths
124
  }