Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import os | |
import shutil | |
def normalize_ratios(ratios):
    """Return ``ratios`` rescaled so that the values sum to 1.

    Args:
        ratios: Iterable of numbers. It may be a one-shot iterator or
            generator; it is materialized once before being summed.

    Returns:
        A list with each input value divided by the sum of all values.
        An empty input yields an empty list.

    Raises:
        ZeroDivisionError: For plain Python numbers, when the input is
            non-empty and its values sum to zero.
    """
    # Materialize first: sum() would exhaust a generator, and the
    # comprehension below would then silently produce an empty list.
    values = list(ratios)
    total = sum(values)
    return [value / total for value in values]
from torch.nn.utils.rnn import pad_sequence | |
def collate_fn_transformer(batch):
    """Collate variable-length waveforms into a padded batch plus a mask.

    Args:
        batch: List of ``(waveform, label)`` pairs. Each waveform is a
            tensor (or anything ``torch.as_tensor`` accepts); singleton
            dimensions are squeezed away before padding.

    Returns:
        A tuple ``(padded_waveforms, attention_mask, labels)``:

        * ``padded_waveforms`` - waveforms zero-padded to the longest one
          in the batch, batch dimension first.
        * ``attention_mask`` - ``torch.long`` tensor with the same shape as
          ``padded_waveforms``; 1 marks real samples, 0 marks padding.
        * ``labels`` - ``torch.long`` tensor with one entry per sample.
    """
    # Separate waveforms and labels.
    waveforms, labels = zip(*batch)

    prepared = []
    for waveform in waveforms:
        if isinstance(waveform, torch.Tensor):
            # torch.tensor(tensor) warns and copies; pad_sequence copies
            # into a fresh tensor anyway, so a detached view is enough.
            sample = waveform.detach()
        else:
            sample = torch.as_tensor(waveform)
        sample = sample.squeeze()
        if sample.dim() == 0:
            # A one-sample waveform squeezes down to a scalar, which
            # pad_sequence cannot pad; restore the sequence dimension.
            sample = sample.reshape(1)
        prepared.append(sample)

    # Pad sequences to the same length.
    padded_waveforms = pad_sequence(prepared, batch_first=True)

    # Build the mask from the true lengths instead of testing for zeros:
    # a genuine zero-valued sample inside a waveform is not padding.
    attention_mask = pad_sequence(
        [torch.ones_like(sample, dtype=torch.long) for sample in prepared],
        batch_first=True,
    )

    # Convert labels to a tensor.
    labels = torch.tensor(labels, dtype=torch.long)
    return padded_waveforms, attention_mask, labels
def collate_fn(batch):
    """Assemble a batch from 4-tuples of per-sample data.

    Args:
        batch: List of ``(input, target, input_length, target_length)``
            samples. Inputs must all share one shape so they can be
            stacked; targets are tensors that can be concatenated.

    Returns:
        A tuple ``(inputs, targets, input_lengths, target_lengths)`` where
        ``inputs`` is the stacked input tensor, ``targets`` is all target
        tensors concatenated along dim 0, and both length entries are
        ``torch.long`` tensors.
        NOTE(review): this layout looks like what a CTC-style loss
        expects - confirm against the training loop.
    """
    feature_list, label_list, feature_lens, label_lens = zip(*batch)

    def _as_long(values):
        return torch.tensor(values, dtype=torch.long)

    return (
        torch.stack(feature_list),  # adds a leading batch dimension
        torch.cat(label_list),  # targets joined end-to-end along dim 0
        _as_long(feature_lens),
        _as_long(label_lens),
    )
def _unique_destination(directory, filename, tag):
    """Return a path inside ``directory`` for ``filename`` that does not exist yet."""
    candidate = os.path.join(directory, filename)
    if not os.path.exists(candidate):
        return candidate
    stem, ext = os.path.splitext(filename)
    candidate = os.path.join(directory, f"{stem}_{tag}{ext}")
    counter = 1
    while os.path.exists(candidate):
        candidate = os.path.join(directory, f"{stem}_{tag}_{counter}{ext}")
        counter += 1
    return candidate


def save_test_data(test_dataset, dataset, save_dir):
    """Copy the audio files of a test split into per-label folders.

    Any existing ``save_dir`` is deleted first, then recreated. Each file
    referenced by ``test_dataset.indices`` is copied to
    ``save_dir/<label>/``.

    Args:
        test_dataset: Object exposing ``indices`` (positions into
            ``dataset``), e.g. a subset produced by a random split.
        dataset: Object exposing ``audio_files`` and ``labels``, both
            indexable by the values in ``test_dataset.indices``.
        save_dir: Destination directory. WARNING: it is removed with all
            of its contents if it already exists.
    """
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)  # Delete the existing directory and its contents
        print(f"Existing test data directory '{save_dir}' removed.")
    os.makedirs(save_dir, exist_ok=True)

    for idx in test_dataset.indices:
        audio_file_path = dataset.audio_files[idx]
        label = dataset.labels[idx]

        # Create a directory for the label if it doesn't exist.
        label_dir = os.path.join(save_dir, str(label))
        os.makedirs(label_dir, exist_ok=True)

        # Files from different source folders can share a basename; copying
        # straight into label_dir would silently overwrite the earlier one.
        # Non-colliding files keep their original name.
        destination = _unique_destination(
            label_dir, os.path.basename(audio_file_path), idx
        )
        shutil.copy(audio_file_path, destination)

    print(f"Test data saved in {save_dir}")