import os
import zipfile
from typing import List, Tuple

import librosa
import numpy as np
import requests
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from tqdm import tqdm

SAMPLE_RATE = 22050
DURATION = 1.4  # seconds

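# Example transform (an illustrative sketch, not part of the original file):
# turns the raw waveform into a log-mel spectrogram, so that input_length in
# __getitem__ below counts spectrogram frames instead of samples. The n_fft,
# hop_length, and n_mels values here are assumptions, not tuned settings.
def log_mel_transform(waveform: torch.Tensor) -> torch.Tensor:
    mel = librosa.feature.melspectrogram(
        y=waveform.numpy(),
        sr=SAMPLE_RATE,
        n_fft=1024,
        hop_length=256,
        n_mels=64,
    )
    return torch.tensor(librosa.power_to_db(mel), dtype=torch.float32)
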
class CTCEmodbDataset(Dataset):
    __url__ = "http://www.emodb.bilderbar.info/download/download.zip"
    __labels__ = ("angry", "happy", "neutral", "sad")
    # EMO-DB file names end with an emotion letter plus a version letter,
    # e.g. "Wa" (W = Wut/anger, F = Freude/happiness, N = neutral,
    # T = Trauer/sadness).
    __suffixes__ = {
        "angry": ["Wa", "Wb", "Wc", "Wd"],
        "happy": ["Fa", "Fb", "Fc", "Fd"],
        "neutral": ["Na", "Nb", "Nc", "Nd"],
        "sad": ["Ta", "Tb", "Tc", "Td"],
    }

    def __init__(self, root_path: str = "./data/emodb", transform=None):
        super().__init__()
        self.root_path = root_path
        self.audio_root_path = os.path.join(root_path, "wav")

        # Ensure the dataset is downloaded
        self._ensure_dataset()

        ids = []
        targets = []
        for audio_file in os.listdir(self.audio_root_path):
            f_name, ext = os.path.splitext(audio_file)
            if ext != ".wav":
                continue
            # Match the two-letter emotion suffix; files outside the four
            # classes are skipped.
            suffix = f_name[-2:]
            for label, suffixes in self.__suffixes__.items():
                if suffix in suffixes:
                    ids.append(os.path.join(self.audio_root_path, audio_file))
                    targets.append(self.label2id(label))  # Store as an integer id
                    break

        self.ids = ids
        self.targets = targets  # One integer label id per utterance
        self.transform = transform

    def _ensure_dataset(self):
        """
        Ensures the dataset is downloaded and extracted.
        """
        if not os.path.isdir(self.audio_root_path):
            print(f"Dataset not found at {self.audio_root_path}. Downloading...")
            self._download_and_extract()

    def _download_and_extract(self):
        """
        Downloads and extracts the dataset zip file.
        """
        os.makedirs(self.root_path, exist_ok=True)
        zip_path = os.path.join(self.root_path, "emodb.zip")

        with requests.get(self.__url__, stream=True) as r:
            r.raise_for_status()
            total_size = int(r.headers.get("content-length", 0))
            with open(zip_path, "wb") as f, tqdm(
                desc="Downloading EMO-DB dataset",
                total=total_size,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
                    bar.update(len(chunk))

        print("Extracting dataset...")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(self.root_path)
        os.remove(zip_path)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
        """
        Returns:
            x (torch.Tensor): Input sequence (raw waveform, or features if a
                transform is given)
            y (torch.Tensor): Target sequence (a single emotion label id for
                this dataset)
            input_length (int): Length of the input sequence
            target_length (int): Length of the target sequence
        """
        target = torch.tensor([self.targets[idx]], dtype=torch.long)
        audio = self.load_audio(self.ids[idx])  # 1-D float32 waveform tensor
        if self.transform:
            audio = self.transform(audio)

        # Lengths are needed by the CTC loss; the last dimension is time.
        input_length = audio.shape[-1]
        target_length = len(target)
        return audio, target, input_length, target_length

    @staticmethod
    def id2label(idx: int) -> str:
        return CTCEmodbDataset.__labels__[idx]

    @staticmethod
    def label2id(label: str) -> int:
        if label not in CTCEmodbDataset.__labels__:
            raise ValueError(f"Unknown label: {label}")
        return CTCEmodbDataset.__labels__.index(label)

    @staticmethod
    def load_audio(audio_file_path: str) -> torch.Tensor:
        # librosa resamples to SAMPLE_RATE and truncates to DURATION seconds.
        audio, sr = librosa.load(audio_file_path, sr=SAMPLE_RATE, duration=DURATION)
        assert SAMPLE_RATE == sr, "broken audio file"
        return torch.tensor(audio, dtype=torch.float32)

    @staticmethod
    def get_labels() -> List[str]:
        return list(CTCEmodbDataset.__labels__)
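
# A CTC-style collate function plus a usage example (a minimal sketch, not
# part of the original file): pad_sequence pads the variable-length waveforms
# in a batch along time, and the per-item lengths are stacked into the tensors
# that torch.nn.CTCLoss expects as input_lengths / target_lengths. This
# assumes each item is the raw 1-D waveform from load_audio (no transform).
def ctc_collate_fn(batch):
    audios, targets, input_lengths, target_lengths = zip(*batch)
    padded_audio = pad_sequence(audios, batch_first=True)  # (B, T_max)
    # Each target is a length-1 tensor (one emotion label per clip), so
    # concatenation yields the flat target layout CTCLoss accepts.
    targets = torch.cat(targets)
    input_lengths = torch.tensor(input_lengths, dtype=torch.long)
    target_lengths = torch.tensor(target_lengths, dtype=torch.long)
    return padded_audio, targets, input_lengths, target_lengths


if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = CTCEmodbDataset()
    loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=ctc_collate_fn)
    x, y, x_lens, y_lens = next(iter(loader))
    print(x.shape, y.shape, x_lens, y_lens)
    # A model's (T, B, C) log-probabilities would then pair with these
    # lengths as: torch.nn.CTCLoss()(log_probs, y, x_lens, y_lens). If the
    # model downsamples along time, x_lens must be scaled accordingly.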