File size: 4,084 Bytes
0874d87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
import numpy as np
import torch
from torch.utils.data import Dataset
import librosa
from typing import List, Tuple
import shutil
import kagglehub
from transformers import Wav2Vec2Processor, Wav2Vec2Model
import subprocess
import zipfile
import os
# Audio-loading constants consumed by TESSRawWaveformDataset.load_audio.
# Adjust them to match the requirements of the downstream model.
SAMPLE_RATE = 16000  # Target sample rate (Hz) requested from librosa.load
DURATION = 3.0  # Upper bound, in seconds, on how much of each file is read

def normalize_waveform(audio: np.ndarray, eps: float = 1e-8) -> torch.Tensor:
    """Standardize a waveform to zero mean and, where possible, unit variance.

    Args:
        audio: Samples as a NumPy array or a ``torch.Tensor``. Non-tensor
            input is copied into a ``float32`` tensor; integer tensors are
            cast to ``float32`` so that mean/std are defined.
        eps: Standard deviations below this threshold are treated as zero.

    Returns:
        A tensor with the same shape as ``audio``. When the signal has no
        measurable spread (silence, a constant signal, or fewer than two
        samples) only the mean is removed, so the result is finite instead
        of NaN/inf.
    """
    if not isinstance(audio, torch.Tensor):
        audio = torch.tensor(audio, dtype=torch.float32)
    elif not torch.is_floating_point(audio):
        # torch.mean / torch.std are not defined for integer dtypes.
        audio = audio.to(torch.float32)

    # Nothing to normalize; mean() of an empty tensor would be NaN.
    if audio.numel() == 0:
        return audio

    centered = audio - torch.mean(audio)

    # The default (unbiased) std is undefined for a single sample.
    if audio.numel() < 2:
        return centered

    std = torch.std(audio)
    # Dividing by a (near-)zero std would fill the waveform with NaN/inf.
    if std < eps:
        return centered
    return centered / std

class TESSRawWaveformDataset(Dataset):
    """Raw-waveform dataset over a directory tree of TESS ``.wav`` files.

    Every ``.wav`` file found under ``root_path`` is labelled with the first
    entry of ``self.emotions`` whose name occurs in the file's directory
    path; files whose directory matches no emotion are skipped. Labels are
    the indices into ``self.emotions``.

    If ``root_path`` does not exist when the dataset is constructed, the
    corpus is downloaded from Kaggle with ``curl`` and unpacked there.

    NOTE(review): ``load_audio`` caps clips at ``DURATION`` seconds but does
    not pad shorter ones, so item lengths presumably vary; confirm that the
    caller batches with a padding collate function.
    """

    def __init__(self, root_path: str, transform=None):
        """Index all labelled ``.wav`` files below ``root_path``.

        Args:
            root_path: Directory holding (or to receive) the dataset.
            transform: Optional callable applied to each normalized waveform
                in ``__getitem__``.
        """
        super().__init__()
        self.root_path = root_path
        self.audio_files = []
        self.labels = []
        self.emotions = ["happy", "sad", "angry", "neutral", "fear", "disgust", "surprise"]
        emotion_mapping = {e.lower(): idx for idx, e in enumerate(self.emotions)}
        self.download_dataset_if_not_exists()

        # Load file paths and labels from nested directories.
        for root, dirs, files in os.walk(root_path):
            # os.walk yields entries in arbitrary filesystem order; sorting
            # (in place for dirs, so the traversal itself is ordered) keeps
            # the index -> sample mapping reproducible across machines.
            dirs.sort()
            wav_files = sorted(f for f in files if f.endswith(".wav"))
            if not wav_files:
                continue

            # The emotion depends only on the directory, so resolve it once.
            emotion_name = self._match_emotion(root, root_path, emotion_mapping)
            if emotion_name is None:
                continue

            label = emotion_mapping[emotion_name]
            for file_name in wav_files:
                self.audio_files.append(os.path.join(root, file_name))
                self.labels.append(label)

        self.labels = np.array(self.labels, dtype=np.int64)
        self.transform = transform

    @staticmethod
    def _match_emotion(directory, root_path, emotion_names):
        """Return the first emotion name found in ``directory``, or ``None``.

        The path relative to ``root_path`` is checked first so that an
        emotion word appearing in the dataset's parent path (for example
        ``/data/sad_corpus/TESS``) cannot override the per-class folder
        name. The full path is used only as a fallback, which keeps the
        case where ``root_path`` itself is a single emotion folder working.
        """
        relative = os.path.relpath(directory, root_path).lower()
        for candidate in (relative, directory.lower()):
            for name in emotion_names:
                if name in candidate:
                    return name
        return None

    def __len__(self):
        """Return the number of indexed audio files."""
        return len(self.audio_files)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Load, normalize and optionally transform the sample at ``idx``.

        Returns:
            A ``(waveform, label)`` pair, where ``label`` is the index of the
            sample's emotion in ``self.emotions``.
        """
        audio_path = self.audio_files[idx]
        label = self.labels[idx]
        waveform = self.load_audio(audio_path)

        if self.transform:
            waveform = self.transform(waveform)

        return waveform, label

    @staticmethod
    def load_audio(audio_path: str) -> torch.Tensor:
        """Read up to ``DURATION`` seconds at ``SAMPLE_RATE`` and normalize."""
        audio, sr = librosa.load(audio_path, sr=SAMPLE_RATE, duration=DURATION)
        # Sanity check on the rate reported back by librosa.
        assert sr == SAMPLE_RATE, f"Sample rate mismatch: expected {SAMPLE_RATE}, got {sr}"
        return normalize_waveform(audio)

    def get_emotions(self) -> List[str]:
        """Return the emotion names; list position equals the label index."""
        return self.emotions

    def download_dataset_if_not_exists(self):
        """Download and unpack the dataset when ``root_path`` is missing.

        Does nothing if ``root_path`` already exists. Otherwise the directory
        is created, the archive is fetched with ``curl``, extracted into
        ``root_path`` and the archive is deleted.

        Raises:
            subprocess.CalledProcessError: If ``curl`` exits with an error.
                Any other failure (missing ``curl`` binary, invalid archive,
                I/O error) propagates unchanged.
        """
        if os.path.exists(self.root_path):
            return

        print(f"Dataset not found at {self.root_path}. Downloading...")

        # Ensure the destination directory exists
        os.makedirs(self.root_path, exist_ok=True)

        dataset_zip_path = os.path.join(self.root_path, "toronto-emotional-speech-set-tess.zip")
        curl_command = [
            "curl",
            # Without --fail, curl exits 0 on HTTP errors and stores the
            # error page as if it were the archive.
            "--fail",
            "-L",
            "-o",
            dataset_zip_path,
            "https://www.kaggle.com/api/v1/datasets/download/ejlok1/toronto-emotional-speech-set-tess",
        ]

        extracted = False
        try:
            subprocess.run(curl_command, check=True)
            print(f"Dataset downloaded to {dataset_zip_path}.")

            # Extract the downloaded zip file
            with zipfile.ZipFile(dataset_zip_path, "r") as zip_ref:
                zip_ref.extractall(self.root_path)
            print(f"Dataset extracted to {self.root_path}.")
            extracted = True

            # Remove the zip file to save space
            os.remove(dataset_zip_path)
            print(f"Removed zip file: {dataset_zip_path}")

        except subprocess.CalledProcessError as e:
            print(f"Error occurred during dataset download: {e}")
            raise
        finally:
            if not extracted:
                # This call created root_path; leaving it behind would make
                # the existence check above skip the download on every later
                # run and silently produce an empty dataset.
                shutil.rmtree(self.root_path, ignore_errors=True)


# Example usage
# dataset = TESSRawWaveformDataset(root_path="./TESS", transform=None)
# print("Number of samples:", len(dataset))