# DeepfakeDetection/dataset.py
import torch
from torch.utils.data import Dataset
from transformers import AutoFeatureExtractor
import os
import librosa
import numpy as np
class DemoDataset(Dataset):
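    """Wraps a single query audio clip together with a set of labeled
    demonstration (in-context) clips for deepfake detection.

    Each item yields w2v-bert features for the query and for every
    demonstration clip, plus the integer demonstration labels
    (0 = bonafide, 1 = spoof).
    """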
    def __init__(self, demonstration_paths, demonstration_labels, query_path, sample_rate=16000):
        self.sample_rate = sample_rate
        self.query_path = query_path
        # Convert to lists if a single path/label was passed
        if isinstance(demonstration_paths, str):
            demonstration_paths = [demonstration_paths]
        if isinstance(demonstration_labels, str):
            demonstration_labels = [demonstration_labels]
        self.demonstration_paths = demonstration_paths
        # Map string labels to integers: 'bonafide' -> 0, anything else -> 1 (spoof)
        self.demonstration_labels = [0 if label == 'bonafide' else 1 for label in demonstration_labels]
# Load feature extractor
self.feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
def load_pad(self, path, max_length=64000):
"""Load and pad audio file"""
        X, _ = librosa.load(path, sr=self.sample_rate)
X = self.pad(X, max_length)
return X
def pad(self, x, max_len=64000):
"""Pad audio to fixed length"""
x_len = x.shape[0]
if x_len >= max_len:
return x[:max_len]
        pad_length = max_len - x_len
        return np.concatenate([x, np.zeros(pad_length, dtype=x.dtype)], axis=0)
def __len__(self):
return 1 # Only one query audio
def __getitem__(self, idx):
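        """Return w2v-bert features for the query audio and every demonstration clip."""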
# Load query audio
query_waveform = self.load_pad(self.query_path)
query_waveform = torch.from_numpy(query_waveform).float()
if len(query_waveform.shape) == 1:
query_waveform = query_waveform.unsqueeze(0)
# Extract features for query audio
main_features = self.feature_extractor(
query_waveform,
sampling_rate=self.sample_rate,
padding=True,
return_attention_mask=True,
return_tensors="pt"
)
# Process demonstration audios
prompt_features = []
for demo_path in self.demonstration_paths:
# Load demonstration audio
demo_waveform = self.load_pad(demo_path)
demo_waveform = torch.from_numpy(demo_waveform).float()
if len(demo_waveform.shape) == 1:
demo_waveform = demo_waveform.unsqueeze(0)
# Extract features
prompt_feature = self.feature_extractor(
demo_waveform,
sampling_rate=self.sample_rate,
padding=True,
return_attention_mask=True,
return_tensors="pt"
)
prompt_features.append(prompt_feature)
        # Shape [1, num_prompts] so collate_fn can concatenate along the batch dim
        prompt_labels = torch.tensor([self.demonstration_labels], dtype=torch.long)
return {
'main_features': main_features,
'prompt_features': prompt_features,
'prompt_labels': prompt_labels,
'file_name': os.path.basename(self.query_path),
'file_path': self.query_path
}
def collate_fn(batch):
"""
Collate function for dataloader
Args:
batch: List containing dictionaries with:
- main_features: feature extractor output
- prompt_features: list of feature extractor outputs
- file_name: file name
- file_path: file path
"""
batch_size = len(batch)
# Process main features
main_features_keys = batch[0]['main_features'].keys()
main_features = {}
for key in main_features_keys:
main_features[key] = torch.cat([item['main_features'][key] for item in batch], dim=0)
# Get number of prompts
num_prompts = len(batch[0]['prompt_features'])
# Process prompt features
prompt_features = []
for i in range(num_prompts):
prompt_feature = {}
for key in main_features_keys:
prompt_feature[key] = torch.cat([item['prompt_features'][i][key] for item in batch], dim=0)
prompt_features.append(prompt_feature)
# Collect file names and paths
file_names = [item['file_name'] for item in batch]
file_paths = [item['file_path'] for item in batch]
    # Ensure prompt_labels has the correct shape: [batch_size, num_prompts]
prompt_labels = torch.cat([item['prompt_labels'] for item in batch], dim=0)
return {
'main_features': main_features,
'prompt_features': prompt_features,
'prompt_labels': prompt_labels,
'file_names': file_names,
'file_paths': file_paths
}
if __name__ == '__main__':
    # Test the dataset (the constructor also requires demonstration labels)
    demo_paths = ["examples/demo1.wav", "examples/demo2.wav"]
    demo_labels = ["bonafide", "spoof"]
    query_path = "examples/query.wav"
    dataset = DemoDataset(demo_paths, demo_labels, query_path)
    print(dataset[0])
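    # A minimal sketch of batching with the custom collate_fn; assumes the
    # example wav files above exist on disk. Batch size and the printed keys
    # are illustrative only.
    from torch.utils.data import DataLoader
    loader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn)
    batch = next(iter(loader))
    for key, value in batch['main_features'].items():
        print(key, value.shape)
    print(batch['prompt_labels'])  # expected shape: [batch_size, num_prompts]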