# emotion-detection/datasets/ctc_audio_dataclass.py
import os
import zipfile
from typing import List, Tuple

import librosa
import numpy as np
import requests
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from tqdm import tqdm

SAMPLE_RATE = 22050  # target sampling rate passed to librosa.load
DURATION = 1.4  # seconds of audio loaded per clip
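# At these settings a clip contributes at most int(SAMPLE_RATE * DURATION) = 30870 samples.
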
class CTCEmodbDataset(Dataset):
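    """EMO-DB (Berlin Database of Emotional Speech) wrapper restricted to four
    emotions, returning (waveform, target, input_length, target_length) tuples
    suitable for use with ``torch.nn.CTCLoss``. EMO-DB file names end in a
    two-character code: the emotion letter (W = anger, F = happiness,
    N = neutral, T = sadness) followed by a version letter, which is what
    ``__suffixes__`` matches against.
    """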
    __url__ = "http://www.emodb.bilderbar.info/download/download.zip"
    __labels__ = ("angry", "happy", "neutral", "sad")
    __suffixes__ = {
        "angry": ["Wa", "Wb", "Wc", "Wd"],
        "happy": ["Fa", "Fb", "Fc", "Fd"],
        "neutral": ["Na", "Nb", "Nc", "Nd"],
        "sad": ["Ta", "Tb", "Tc", "Td"],
    }

    def __init__(self, root_path: str = './data/emodb', transform=None):
        super().__init__()
        self.root_path = root_path
        self.audio_root_path = os.path.join(root_path, "wav")
        # Ensure the dataset is downloaded
        self._ensure_dataset()
        ids = []
        targets = []
        for audio_file in os.listdir(self.audio_root_path):
            f_name, ext = os.path.splitext(audio_file)
            if ext != ".wav":
                continue
            # Last two characters of the stem encode emotion + version, e.g. "Fa"
            suffix = f_name[-2:]
            for label, suffixes in self.__suffixes__.items():
                if suffix in suffixes:
                    ids.append(os.path.join(self.audio_root_path, audio_file))
                    targets.append(self.label2id(label))  # Store targets as integer ids
                    break
        self.ids = ids
        self.targets = targets  # One integer label id per file; wrapped into a length-1 sequence in __getitem__
        self.transform = transform

    def _ensure_dataset(self):
        """
        Ensures the dataset is downloaded and extracted.
        """
        if not os.path.isdir(self.audio_root_path):
            print(f"Dataset not found at {self.audio_root_path}. Downloading...")
            self._download_and_extract()

    def _download_and_extract(self):
        """
        Downloads and extracts the dataset zip file.
        """
        os.makedirs(self.root_path, exist_ok=True)
        zip_path = os.path.join(self.root_path, "emodb.zip")
        with requests.get(self.__url__, stream=True) as r:
            r.raise_for_status()
            total_size = int(r.headers.get("content-length", 0))
            with open(zip_path, "wb") as f, tqdm(
                desc="Downloading EMO-DB dataset",
                total=total_size,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
                    bar.update(len(chunk))
        print("Extracting dataset...")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(self.root_path)
        os.remove(zip_path)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
        """
        Returns:
            x (torch.Tensor): Input sequence (audio features or waveform)
            y (torch.Tensor): Target sequence (labels or tokenized transcription)
            input_length (int): Length of input sequence
            target_length (int): Length of target sequence
        """
        target = torch.tensor([self.targets[idx]], dtype=torch.long)
        audio = self.load_audio(self.ids[idx])  # Returns a 1-D float32 torch.Tensor
        if self.transform:
            audio = self.transform(audio)
        # Input length (for CTC): the last dimension is the time dimension
        input_length = audio.shape[-1]
        target_length = len(target)  # Length of the target sequence (one label per utterance)
        return audio, target, input_length, target_length

    @staticmethod
    def id2label(idx: int) -> str:
        return CTCEmodbDataset.__labels__[idx]

    @staticmethod
    def label2id(label: str) -> int:
        if label not in CTCEmodbDataset.__labels__:
            raise ValueError(f"Unknown label: {label}")
        return CTCEmodbDataset.__labels__.index(label)

    @staticmethod
    def load_audio(audio_file_path: str) -> torch.Tensor:
        audio, sr = librosa.load(audio_file_path, sr=SAMPLE_RATE, duration=DURATION)
        # librosa.load resamples to the requested rate, so sr should always match
        assert SAMPLE_RATE == sr, "broken audio file"
        return torch.tensor(audio, dtype=torch.float32)

    @staticmethod
    def get_labels() -> List[str]:
        return list(CTCEmodbDataset.__labels__)
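

# A minimal usage sketch (an illustration, not part of the original file): a
# CTC-style collate function built on the `pad_sequence` import above. The name
# `ctc_collate_fn` and the batch layout are assumptions, not defined elsewhere
# in this repo.
def ctc_collate_fn(batch):
    """Pad variable-length waveforms and concatenate targets for nn.CTCLoss."""
    audios, targets, input_lengths, target_lengths = zip(*batch)
    # (batch, max_time): pad each 1-D waveform to the longest clip in the batch
    padded_audio = pad_sequence(audios, batch_first=True)
    # nn.CTCLoss also accepts targets concatenated into a single 1-D tensor
    concat_targets = torch.cat(targets)
    return (
        padded_audio,
        concat_targets,
        torch.tensor(input_lengths, dtype=torch.long),
        torch.tensor(target_lengths, dtype=torch.long),
    )


if __name__ == "__main__":
    # Quick smoke test: build the dataset and fetch one padded batch.
    from torch.utils.data import DataLoader

    dataset = CTCEmodbDataset()
    loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=ctc_collate_fn)
    audio, targets, input_lengths, target_lengths = next(iter(loader))
    print(audio.shape, targets.shape, input_lengths, target_lengths)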