File size: 2,478 Bytes
0874d87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch
import torch.nn as nn
import os
import shutil
def normalize_ratios(ratios):
    """Scale *ratios* so they sum to 1 while preserving their proportions."""
    denom = sum(ratios)
    return [value / denom for value in ratios]

from torch.nn.utils.rnn import pad_sequence

def collate_fn_transformer(batch):
    """
    Custom collate function to handle variable-length raw waveform inputs.

    Args:
        batch: List of tuples (waveform, label), where each waveform is a
            1-D sequence (tensor or array-like) of shape [sequence_length].

    Returns:
        padded_waveforms: Tensor of shape [batch_size, max_seq_len],
            right-padded with zeros to the longest waveform in the batch.
        attention_mask: Long tensor of the same shape; 1 at non-zero
            positions, 0 at zero (padded) positions.
        labels: Long tensor of shape [batch_size].
    """
    # Separate waveforms and labels
    waveforms, labels = zip(*batch)

    # as_tensor avoids the copy (and UserWarning) that torch.tensor() emits
    # when the input is already a tensor; squeeze drops singleton dims so
    # each waveform is 1-D for pad_sequence.
    waveforms = [torch.as_tensor(waveform).squeeze() for waveform in waveforms]

    # Pad sequences to the same length (zeros on the right)
    padded_waveforms = pad_sequence(waveforms, batch_first=True)  # [batch_size, max_seq_len]

    # Create attention mask.
    # NOTE(review): this marks *any* exactly-zero sample as padding, so real
    # zero-valued samples inside a waveform are masked too. If that matters,
    # build the mask from the original sequence lengths instead — confirm.
    attention_mask = (padded_waveforms != 0).long()  # Mask for non-padded values

    # Convert labels to a tensor
    labels = torch.tensor(labels, dtype=torch.long)

    return padded_waveforms, attention_mask, labels

def collate_fn(batch):
    """
    Collate (input, target, input_length, target_length) samples into
    CTC-style batch tensors.

    Returns:
        inputs: stacked input tensors, shape [batch_size, ...].
        targets: single 1-D tensor of all target sequences concatenated.
        input_lengths, target_lengths: long tensors of per-sample lengths.
    """
    inputs, targets, in_lens, tgt_lens = zip(*batch)
    batched_inputs = torch.stack(inputs)   # one tensor per batch
    flat_targets = torch.cat(targets)      # targets laid end to end
    return (
        batched_inputs,
        flat_targets,
        torch.tensor(in_lens, dtype=torch.long),
        torch.tensor(tgt_lens, dtype=torch.long),
    )



def save_test_data(test_dataset, dataset, save_dir):
    """
    Copy the audio files of a test split into save_dir, grouped by label.

    Args:
        test_dataset: Subset-like object exposing `indices` into `dataset`.
        dataset: Object exposing parallel `audio_files` and `labels`
            sequences (assumed attributes — confirm against the dataset class).
        save_dir: Destination directory; wiped and recreated on every call.
    """
    # Start from a clean slate so stale files from a previous export
    # cannot mix into the new one.
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)  # Delete the existing directory and its contents
        print(f"Existing test data directory '{save_dir}' removed.")
    os.makedirs(save_dir, exist_ok=True)

    for sample_idx in test_dataset.indices:
        source_path = dataset.audio_files[sample_idx]
        # One subdirectory per label; created lazily as labels appear.
        label_dir = os.path.join(save_dir, str(dataset.labels[sample_idx]))
        os.makedirs(label_dir, exist_ok=True)
        shutil.copy(source_path, label_dir)

    print(f"Test data saved in {save_dir}")