# NOTE: The lines below were page chrome captured when this file was copied
# from its Hugging Face Spaces listing (status, file size, commit 0874d87,
# line-number gutter). They are not part of the module and are kept only as
# provenance, commented out so the file parses.
import torch
import torch.nn as nn
import os
import shutil
def normalize_ratios(ratios):
    """Return *ratios* rescaled so that the entries sum to 1, preserving proportions."""
    denom = sum(ratios)
    return [value / denom for value in ratios]
from torch.nn.utils.rnn import pad_sequence
def collate_fn_transformer(batch):
    """
    Collate variable-length raw waveforms into a zero-padded batch.

    Args:
        batch: List of (waveform, label) tuples, where each waveform is a
            1D tensor (or array-like) of shape [sequence_length].

    Returns:
        padded_waveforms: Tensor of shape [batch_size, max_seq_len], zero-padded.
        attention_mask: Long tensor [batch_size, max_seq_len]; 1 at real
            samples, 0 at padding positions.
        labels: Long tensor of shape [batch_size].
    """
    # Separate waveforms and labels
    waveforms, labels = zip(*batch)
    # as_tensor avoids the extra copy (and UserWarning) that torch.tensor
    # emits when the input is already a tensor; squeeze ensures 1D.
    waveforms = [torch.as_tensor(waveform).squeeze() for waveform in waveforms]
    # Record true lengths BEFORE padding: a mask derived from `padded != 0`
    # would wrongly mask out genuine zero-valued audio samples.
    lengths = torch.tensor([w.numel() for w in waveforms], dtype=torch.long)
    # Pad sequences to the same length
    padded_waveforms = pad_sequence(waveforms, batch_first=True)  # [batch_size, max_seq_len]
    # Position i is valid iff i < that sequence's original length.
    max_len = padded_waveforms.size(1)
    attention_mask = (torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)).long()
    # Convert labels to a tensor
    labels = torch.tensor(labels, dtype=torch.long)
    return padded_waveforms, attention_mask, labels
def collate_fn(batch):
    """Batch (input, target, input_length, target_length) tuples for CTC-style training.

    Inputs are stacked into one tensor (they must share a shape); targets are
    concatenated into a single flat 1D tensor; both length lists become long tensors.
    """
    inputs, targets, input_lengths, target_lengths = zip(*batch)
    batched_inputs = torch.stack(inputs)   # [batch_size, ...] fixed-size inputs
    flat_targets = torch.cat(targets)      # all target sequences, concatenated
    in_lens = torch.tensor(input_lengths, dtype=torch.long)
    tgt_lens = torch.tensor(target_lengths, dtype=torch.long)
    return batched_inputs, flat_targets, in_lens, tgt_lens
def save_test_data(test_dataset, dataset, save_dir):
    """
    Export the audio files referenced by a test split into `save_dir`,
    grouped into one subdirectory per label.

    Args:
        test_dataset: Subset-like object exposing `.indices` into `dataset`.
        dataset: Dataset exposing parallel `.audio_files` (source paths) and
            `.labels` (one label per file) sequences.
        save_dir: Destination directory; wiped and recreated if it exists.
    """
    # Start from a clean slate so stale files from a previous split cannot
    # leak into the exported test set.
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)  # Delete the existing directory and its contents
        print(f"Existing test data directory '{save_dir}' removed.")
    os.makedirs(save_dir, exist_ok=True)
    for idx in test_dataset.indices:
        audio_file_path = dataset.audio_files[idx]
        label = dataset.labels[idx]
        # Create a directory for the label if it doesn't exist
        label_dir = os.path.join(save_dir, str(label))
        os.makedirs(label_dir, exist_ok=True)
        # Copy the audio file to the label directory
        shutil.copy(audio_file_path, label_dir)
    print(f"Test data saved in {save_dir}")