emotion-detection / utils /helper_functions.py
saeedbenadeeb's picture
First commit
0874d87
import torch
import torch.nn as nn
import os
import shutil
def normalize_ratios(ratios):
    """
    Scale a collection of ratios so that they sum to 1.

    Args:
        ratios: Iterable of numbers. One-shot iterables (e.g. generators)
            are supported.
    Returns:
        List containing each ratio divided by the sum of all ratios, in the
        same order as the input. An empty input yields an empty list.
    Raises:
        ZeroDivisionError: If plain Python numbers are given, the input is
            non-empty and the ratios sum to zero.
    """
    # Materialize first: sum() would exhaust a generator and the
    # comprehension below would then silently produce an empty list.
    values = list(ratios)
    total = sum(values)
    return [value / total for value in values]
from torch.nn.utils.rnn import pad_sequence
def collate_fn_transformer(batch):
    """
    Custom collate function to handle variable-length raw waveform inputs.

    Args:
        batch: List of tuples (waveform, label). Each waveform is a tensor
            (or array-like) that is 1D once singleton dimensions are
            squeezed, i.e. effectively of shape [sequence_length].
    Returns:
        padded_waveforms: Zero-padded tensor of shape [batch_size, max_seq_len].
        attention_mask: Long tensor of shape [batch_size, max_seq_len] holding
            1 at positions that belong to the original waveform and 0 at
            padded positions.
        labels: Long tensor of shape [batch_size].
    """
    # Separate waveforms and labels
    waveforms, labels = zip(*batch)

    # Ensure waveforms are 1D tensors
    prepared = []
    for waveform in waveforms:
        # as_tensor accepts tensors and array-likes alike without the
        # copy-construct warning that torch.tensor(tensor) emits.
        signal = torch.as_tensor(waveform).squeeze()
        if signal.dim() == 0:
            # A single-sample waveform squeezes down to a scalar, which
            # pad_sequence cannot handle; restore the length-1 time axis.
            signal = signal.unsqueeze(0)
        prepared.append(signal)

    # Pad sequences to the same length
    padded_waveforms = pad_sequence(prepared, batch_first=True)  # [batch_size, max_seq_len]

    # Build the mask from the true lengths rather than from `value != 0`:
    # real audio can contain exact zeros (silence) that must not be
    # treated as padding.
    device = padded_waveforms.device
    lengths = torch.tensor(
        [signal.size(0) for signal in prepared], dtype=torch.long, device=device
    )
    positions = torch.arange(padded_waveforms.size(1), device=device)
    attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()

    # Convert labels to a tensor
    labels = torch.tensor(labels, dtype=torch.long)
    return padded_waveforms, attention_mask, labels
def collate_fn(batch):
    """
    Collate (input, target, input_length, target_length) samples into a batch.

    Args:
        batch: List of 4-tuples. The input tensors must share one shape so
            they can be stacked; the target tensors are joined end to end.
    Returns:
        inputs: Inputs stacked along a new leading batch dimension.
        targets: All targets concatenated along their first dimension.
        input_lengths: Long tensor with one input length per sample.
        target_lengths: Long tensor with one target length per sample.
    """
    features, label_seqs, feature_lens, label_lens = zip(*batch)

    stacked_inputs = torch.stack(features)
    joined_targets = torch.cat(label_seqs)

    # Both length tuples are converted the same way, in input-then-target order.
    input_lengths, target_lengths = (
        torch.tensor(lens, dtype=torch.long) for lens in (feature_lens, label_lens)
    )
    return stacked_inputs, joined_targets, input_lengths, target_lengths
def save_test_data(test_dataset, dataset, save_dir):
    """
    Copy the audio files of a test split into per-label folders under save_dir.

    The resulting layout is `save_dir/<label>/<file name>`. If two files with
    the same label share a file name, the later one is stored as
    `<stem>_<index><ext>` instead of overwriting the earlier copy.

    Args:
        test_dataset: Object exposing `indices`, the positions within
            `dataset` that make up the test split.
        dataset: Object exposing `audio_files` (file paths) and `labels`,
            both indexable by those positions.
        save_dir: Destination directory. WARNING: if it already exists it is
            deleted with all of its contents before the copy starts.
    """
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)  # Delete the existing directory and its contents
        print(f"Existing test data directory '{save_dir}' removed.")
    os.makedirs(save_dir, exist_ok=True)

    for idx in test_dataset.indices:
        audio_file_path = dataset.audio_files[idx]
        label = dataset.labels[idx]

        # Create a directory for the label if it doesn't exist
        label_dir = os.path.join(save_dir, str(label))
        os.makedirs(label_dir, exist_ok=True)

        file_name = os.path.basename(audio_file_path)
        destination = os.path.join(label_dir, file_name)
        if os.path.exists(destination):
            # save_dir was emptied above, so an existing file can only come
            # from this run: a different source with the same base name.
            # Add the dataset index so neither sample is lost.
            stem, extension = os.path.splitext(file_name)
            destination = os.path.join(label_dir, f"{stem}_{idx}{extension}")

        # Copy the audio file to the label directory
        shutil.copy(audio_file_path, destination)

    print(f"Test data saved in {save_dir}")