Spaces:
Sleeping
Sleeping
import os | |
import glob | |
import torch | |
import torchaudio | |
import torchvision | |
from torch.utils.data import Dataset | |
from concurrent.futures import ThreadPoolExecutor | |
from preprocess import process_audio_data, process_image_data, resample_rate | |
class PreprocessedDataset(Dataset):
    """Map-style dataset over the ``*.pt`` files found directly in a directory.

    Each file must contain a 3-tuple ``(mfcc, image, label)`` -- the format
    ``process_sample`` saves with ``torch.save``.
    """

    def __init__(self, data_dir):
        """Index every ``*.pt`` file directly inside ``data_dir``.

        Args:
            data_dir: Directory holding the saved sample files. Sub-directories
                are not searched.
        """
        self.data_dir = data_dir
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sort so that index -> sample mapping (and any index-based split) is
        # reproducible across runs and machines.
        self.samples = sorted(
            os.path.join(data_dir, name)
            for name in os.listdir(data_dir)
            if name.endswith(".pt")
        )

    def __len__(self):
        """Return the number of indexed ``.pt`` files."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as ``(mfcc, image, label)``."""
        sample_path = self.samples[idx]
        mfcc, image, label = torch.load(sample_path)
        # NOTE(review): process_sample already stores the *outputs* of
        # process_audio_data / process_image_data, so these calls process the
        # data a second time, and resample_rate is passed where process_sample
        # passes the file's real sample rate. Confirm against preprocess.py
        # whether this re-processing is intended.
        mfcc = process_audio_data(mfcc, resample_rate)
        image = process_image_data(image)
        return mfcc, image, label
def load_audio_file(audio_path):
    """Load an audio file and return ``(waveform, sample_rate)``.

    ``torchaudio.load`` is tried first. If it raises, the file is decoded with
    librosa instead; librosa is an optional dependency imported lazily
    (``pip install librosa``).

    Args:
        audio_path: Path to the audio file.

    Returns:
        ``(waveform, sample_rate)``. On the librosa path ``waveform`` is a
        float32 tensor shaped ``[channels, frames]``, i.e. the same channel-first
        layout ``torchaudio.load`` returns, and the native sample rate is kept.

    Raises:
        FileNotFoundError: If ``audio_path`` does not exist.
        RuntimeError: If neither loader can decode the file. The librosa-side
            error is chained as ``__cause__``.
    """
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")
    try:
        # Try the default torchaudio loader first.
        waveform, sample_rate = torchaudio.load(audio_path)
    except Exception as e:
        print(f"Warning: Could not load {audio_path} with torchaudio: {e}")
        # Fall back to librosa (optional dependency: pip install librosa).
        try:
            import librosa
            import numpy as np

            # mono=False keeps every channel. librosa's default (mono=True)
            # would silently downmix multi-channel files, unlike torchaudio.
            # sr=None keeps the file's native sample rate (no resampling).
            waveform_np, sample_rate = librosa.load(audio_path, sr=None, mono=False)
            # Mono audio comes back 1-D; promote it to [1, frames] so both
            # loaders yield a channel-first 2-D tensor.
            waveform = torch.from_numpy(np.atleast_2d(waveform_np)).float()
            print(f"Successfully loaded with librosa: {audio_path}")
        except Exception as final_e:
            raise RuntimeError(
                f"Failed to load audio file {audio_path} with all available methods: {final_e}"
            ) from final_e
    return waveform, sample_rate
def load_image_file(image_path):
    """Read an image from disk with ``torchvision.io.read_image``.

    Args:
        image_path: Path to the image file.

    Returns:
        Whatever ``torchvision.io.read_image`` returns for the file.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
    """
    if os.path.exists(image_path):
        return torchvision.io.read_image(image_path)
    raise FileNotFoundError(f"Image file not found: {image_path}")
def _find_media_files(sample_path):
    """Recursively collect audio and image files under ``sample_path``.

    Returns:
        ``(audio_files, image_files)``: two lists of paths, each sorted so that
        "the first file" is the same on every run and filesystem. os.walk's own
        order is arbitrary.
    """
    audio_files = []
    image_files = []
    for root, _, files in os.walk(sample_path):
        for file in files:
            lowered = file.lower()
            if lowered.endswith(('.wav', '.mp3', '.flac')):
                audio_files.append(os.path.join(root, file))
            elif lowered.endswith(('.jpg', '.jpeg', '.png')):
                image_files.append(os.path.join(root, file))
    audio_files.sort()
    image_files.sort()
    return audio_files, image_files


def process_sample(sample_path, save_dir):
    """Preprocess one sample directory and save it as ``<save_dir>/<name>.pt``.

    The directory is searched recursively. The first audio file and the first
    image file (in sorted path order) are used. The directory's own name is the
    sample's label and must parse as a float. The saved object is the tuple
    ``(mfcc, processed_image, label)``.

    Args:
        sample_path: Directory holding the sample's audio and image files.
        save_dir: Directory to write the ``.pt`` file into (not created here).

    Returns:
        True if the sample was processed and saved. False if it was skipped
        because no audio file or no image file was found.

    Raises:
        ValueError: If the directory name is not a number.
        Exceptions raised by the loaders / processing helpers propagate.
    """
    audio_files, image_files = _find_media_files(sample_path)
    if not audio_files:
        print(f"Warning: No audio file found in {sample_path}. Skipping this sample.")
        return False
    if not image_files:
        print(f"Warning: No image file found in {sample_path}. Skipping this sample.")
        return False

    # normpath strips a trailing separator, which would otherwise make
    # basename() return "" (unparseable label, output file named ".pt").
    sample_name = os.path.basename(os.path.normpath(sample_path))
    # Parse the label before the expensive load/process work so a badly named
    # directory fails fast with a clear message.
    try:
        label = float(sample_name)
    except ValueError as err:
        raise ValueError(
            f"Sample directory name {sample_name!r} is not a numeric label: {sample_path}"
        ) from err

    audio_path = audio_files[0]
    image_path = image_files[0]
    print(f"Processing audio: {audio_path}")
    print(f"Processing image: {image_path}")
    waveform, sample_rate = load_audio_file(audio_path)
    image = load_image_file(image_path)

    mfcc = process_audio_data(waveform, sample_rate)
    processed_image = process_image_data(image)

    save_path = os.path.join(save_dir, f"{sample_name}.pt")
    torch.save((mfcc, processed_image, label), save_path)
    print(f"Processed and saved: {save_path}")
    return True
def process_and_save(data_dir, save_dir, max_workers=None):
    """Preprocess every immediate sub-directory of ``data_dir`` on a thread pool.

    Each sub-directory is passed to ``process_sample(path, save_dir)``.
    ``save_dir`` is created if needed. A failure in one sample is reported and
    counted rather than aborting the run.

    Args:
        data_dir: Directory whose immediate sub-directories are the samples.
        save_dir: Output directory for the processed files.
        max_workers: Thread-pool size. ``None`` (the default) uses
            ``ThreadPoolExecutor``'s own default.
    """
    os.makedirs(save_dir, exist_ok=True)
    # Sorted so samples are submitted (and reported) in a reproducible order.
    sample_paths = sorted(
        path
        for path in (os.path.join(data_dir, d) for d in os.listdir(data_dir))
        if os.path.isdir(path)
    )
    if not sample_paths:
        print(f"Warning: No sample directories found in {data_dir}")
        return
    print(f"Found {len(sample_paths)} sample directories to process")

    successful = 0
    skipped = 0
    failed = 0
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(process_sample, path, save_dir): path
            for path in sample_paths
        }
        for future, path in futures.items():
            try:
                result = future.result()  # blocks until this sample finishes
            except Exception as e:
                failed += 1
                print(f"Error processing sample {path}: {e}")
            else:
                # An explicit False return means the sample was skipped; any
                # other return value (including None) counts as success.
                if result is False:
                    skipped += 1
                else:
                    successful += 1
    print(
        f"Processing complete. Successfully processed: {successful}, "
        f"Skipped: {skipped}, Failed: {failed}"
    )
if __name__ == "__main__":
    # Command-line entry point: parse the input/output directories and run
    # the preprocessing pipeline over them.
    import argparse

    cli = argparse.ArgumentParser(description="Preprocess the dataset")
    for flag, default_value, help_text in (
        ("--data_dir", "cleaned", "Path to the cleaned dataset directory"),
        ("--save_dir", "processed", "Path to the processed dataset directory"),
    ):
        cli.add_argument(flag, type=str, default=default_value, help=help_text)
    options = cli.parse_args()

    print(f"Processing dataset from: {options.data_dir}")
    print(f"Saving processed data to: {options.save_dir}")
    process_and_save(options.data_dir, options.save_dir)
    print("Preprocessing complete")