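"""Dataset for a deepfake-detection demo: pairs one query audio clip with a
set of labeled bonafide/spoof demonstration (prompt) clips, extracts
facebook/w2v-bert-2.0 features for all of them, and provides a collate_fn
for batching."""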
import os

import librosa
import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import AutoFeatureExtractor

class DemoDataset(Dataset):
    def __init__(self, demonstration_paths, demonstration_labels, query_path, sample_rate=16000):
        self.sample_rate = sample_rate
        self.query_path = query_path

        # Accept a single path or label as well as lists
        if isinstance(demonstration_paths, str):
            demonstration_paths = [demonstration_paths]
        if isinstance(demonstration_labels, str):
            demonstration_labels = [demonstration_labels]
        self.demonstration_paths = demonstration_paths
        # Map string labels to integers: 0 = bonafide, 1 = spoof
        self.demonstration_labels = [0 if label == 'bonafide' else 1 for label in demonstration_labels]

        # Load the feature extractor
        self.feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
    def load_pad(self, path, max_length=64000):
        """Load an audio file and pad/truncate it to a fixed length."""
        X, sr = librosa.load(path, sr=self.sample_rate)
        X = self.pad(X, max_length)
        return X

    def pad(self, x, max_len=64000):
        """Pad (or truncate) audio to max_len samples (4 s at 16 kHz)."""
        x_len = x.shape[0]
        if x_len >= max_len:
            return x[:max_len]
        pad_length = max_len - x_len
        return np.concatenate([x, np.zeros(pad_length)], axis=0)
    def __len__(self):
        return 1  # Only one query audio
    def __getitem__(self, idx):
        # Load the query audio
        query_waveform = self.load_pad(self.query_path)
        query_waveform = torch.from_numpy(query_waveform).float()
        if query_waveform.dim() == 1:
            query_waveform = query_waveform.unsqueeze(0)

        # Extract features for the query audio
        main_features = self.feature_extractor(
            query_waveform,
            sampling_rate=self.sample_rate,
            padding=True,
            return_attention_mask=True,
            return_tensors="pt"
        )
        # Process the demonstration audios the same way
        prompt_features = []
        for demo_path in self.demonstration_paths:
            demo_waveform = self.load_pad(demo_path)
            demo_waveform = torch.from_numpy(demo_waveform).float()
            if demo_waveform.dim() == 1:
                demo_waveform = demo_waveform.unsqueeze(0)

            prompt_feature = self.feature_extractor(
                demo_waveform,
                sampling_rate=self.sample_rate,
                padding=True,
                return_attention_mask=True,
                return_tensors="pt"
            )
            prompt_features.append(prompt_feature)
        # Shape [1, num_prompts] so collate_fn can concatenate along dim 0
        prompt_labels = torch.tensor([self.demonstration_labels], dtype=torch.long)

        return {
            'main_features': main_features,
            'prompt_features': prompt_features,
            'prompt_labels': prompt_labels,
            'file_name': os.path.basename(self.query_path),
            'file_path': self.query_path
        }

def collate_fn(batch):
    """
    Collate function for the dataloader.

    Args:
        batch: list of dictionaries, each with:
            - main_features: feature extractor output for the query audio
            - prompt_features: list of feature extractor outputs for the demonstrations
            - prompt_labels: demonstration labels, shape [1, num_prompts]
            - file_name: file name
            - file_path: file path
    """
    # Merge the query (main) features across the batch
    main_features_keys = batch[0]['main_features'].keys()
    main_features = {}
    for key in main_features_keys:
        main_features[key] = torch.cat([item['main_features'][key] for item in batch], dim=0)

    # Number of demonstration prompts (assumed identical for every item)
    num_prompts = len(batch[0]['prompt_features'])

    # Merge prompt features, one dict per prompt position
    prompt_features = []
    for i in range(num_prompts):
        prompt_feature = {}
        for key in main_features_keys:
            prompt_feature[key] = torch.cat([item['prompt_features'][i][key] for item in batch], dim=0)
        prompt_features.append(prompt_feature)

    # Collect file names and paths
    file_names = [item['file_name'] for item in batch]
    file_paths = [item['file_path'] for item in batch]
    # Ensure prompt_labels has the correct shape [batch_size, num_prompts]
    prompt_labels = torch.cat([item['prompt_labels'] for item in batch], dim=0)
    return {
        'main_features': main_features,
        'prompt_features': prompt_features,
        'prompt_labels': prompt_labels,
        'file_names': file_names,
        'file_paths': file_paths
    }

if __name__ == '__main__':
    # Smoke test; the constructor also needs a label for each demonstration
    demo_paths = ["examples/demo1.wav", "examples/demo2.wav"]
    demo_labels = ["bonafide", "spoof"]
    query_path = "examples/query.wav"
    dataset = DemoDataset(demo_paths, demo_labels, query_path)
    print(dataset[0])
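    # A minimal batching sketch, assuming the example files above exist:
    # feed the dataset through a DataLoader with the custom collate_fn.
    # batch_size=1 mirrors __len__, which always returns 1.
    from torch.utils.data import DataLoader
    loader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn)
    for batch in loader:
        # prompt_labels is [batch_size, num_prompts] -> here [1, 2]
        print(batch['file_names'], batch['prompt_labels'].shape)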