import os
import torch
import torchaudio
import torchvision
from torch.utils.data import Dataset
from concurrent.futures import ThreadPoolExecutor

from preprocess import process_audio_data, process_image_data, resample_rate


class PreprocessedDataset(Dataset):
    """Dataset over the (mfcc, image, label) tuples that process_sample saves as .pt files."""

    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.samples = [
            os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".pt")
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample_path = self.samples[idx]
        mfcc, image, label = torch.load(sample_path)
        # Process data
        mfcc = process_audio_data(mfcc, resample_rate)
        image = process_image_data(image)
        return mfcc, image, label
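
# A minimal usage sketch for PreprocessedDataset (illustrative only; the
# "processed" directory is the default output of process_and_save below,
# and the batch size is an arbitrary choice):
#
#   from torch.utils.data import DataLoader
#
#   dataset = PreprocessedDataset("processed")
#   loader = DataLoader(dataset, batch_size=32, shuffle=True)
#   for mfcc, image, label in loader:
#       ...  # feed the batch to a model
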
def load_audio_file(audio_path):
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")
    try:
        # Try the default torchaudio loader first
        waveform, sample_rate = torchaudio.load(audio_path)
    except Exception as e:
        print(f"Warning: Could not load {audio_path} with torchaudio: {e}")
        # Fall back to librosa (you'll need to install it: pip install librosa)
        try:
            import librosa
            import numpy as np

            waveform_np, sample_rate = librosa.load(audio_path, sr=None)
            # Convert to a torch tensor with shape [1, length] to match the torchaudio format
            waveform = torch.from_numpy(waveform_np[np.newaxis, :]).float()
            print(f"Successfully loaded with librosa: {audio_path}")
        except Exception as final_e:
            raise RuntimeError(
                f"Failed to load audio file {audio_path} with all available methods: {final_e}"
            )
    return waveform, sample_rate


def load_image_file(image_path):
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")
    image = torchvision.io.read_image(image_path)
    return image
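
# Expected layout of the raw data directory, as implied by process_sample and
# process_and_save below (file names here are only illustrative; each sample
# directory's name is parsed as a float label):
#
#   <data_dir>/
#       1.0/
#           recording.wav    # first .wav/.mp3/.flac found anywhere in the subtree is used
#           photo.jpg        # first .jpg/.jpeg/.png found anywhere in the subtree is used
#       2.5/
#           ...
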
def process_sample(sample_path, save_dir):
    # Recursively search for audio and image files
    audio_files = []
    image_files = []
    # Walk through all subdirectories
    for root, _, files in os.walk(sample_path):
        for file in files:
            if file.lower().endswith(('.wav', '.mp3', '.flac')):
                audio_files.append(os.path.join(root, file))
            elif file.lower().endswith(('.jpg', '.jpeg', '.png')):
                image_files.append(os.path.join(root, file))
    if not audio_files:
        print(f"Warning: No audio file found in {sample_path}. Skipping this sample.")
        return
    if not image_files:
        print(f"Warning: No image file found in {sample_path}. Skipping this sample.")
        return
    # Use the first found audio and image files
    audio_path = audio_files[0]
    image_path = image_files[0]
    print(f"Processing audio: {audio_path}")
    print(f"Processing image: {image_path}")
    waveform, sample_rate = load_audio_file(audio_path)
    image = load_image_file(image_path)
    # Process data
    mfcc = process_audio_data(waveform, sample_rate)
    processed_image = process_image_data(image)
    # Save processed data; the sample directory's name is used as the numeric label
    save_path = os.path.join(save_dir, f"{os.path.basename(sample_path)}.pt")
    torch.save((mfcc, processed_image, float(os.path.basename(sample_path))), save_path)
    print(f"Processed and saved: {save_path}")


def process_and_save(data_dir, save_dir):
    os.makedirs(save_dir, exist_ok=True)
    sample_paths = [
        os.path.join(data_dir, d)
        for d in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, d))
    ]
    if not sample_paths:
        print(f"Warning: No sample directories found in {data_dir}")
        return
    print(f"Found {len(sample_paths)} sample directories to process")
    successful = 0
    failed = 0
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(process_sample, path, save_dir) for path in sample_paths]
        for future in futures:
            try:
                future.result()  # Wait for all threads to complete
                successful += 1
            except Exception as e:
                failed += 1
                print(f"Error processing a sample: {e}")
    print(f"Processing complete. Successfully processed: {successful}, Failed: {failed}")
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Preprocess the dataset")
    parser.add_argument(
        "--data_dir",
        type=str,
        default="cleaned",
        help="Path to the cleaned dataset directory",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="processed",
        help="Path to the processed dataset directory",
    )
    args = parser.parse_args()

    print(f"Processing dataset from: {args.data_dir}")
    print(f"Saving processed data to: {args.save_dir}")
    process_and_save(args.data_dir, args.save_dir)
    print("Preprocessing complete")