# watermelon/preprocess_file.py
import os
import glob
import torch
import torchaudio
import torchvision
from torch.utils.data import Dataset
from concurrent.futures import ThreadPoolExecutor
from preprocess import process_audio_data, process_image_data, resample_rate


class PreprocessedDataset(Dataset):
    """Dataset of preprocessed (mfcc, image, label) tuples stored as .pt files."""

    def __init__(self, data_dir):
        self.data_dir = data_dir
        # One .pt file per sample, written by process_and_save below.
        self.samples = [
            os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".pt")
        ]
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
sample_path = self.samples[idx]
mfcc, image, label = torch.load(sample_path)
# Process data
mfcc = process_audio_data(mfcc, resample_rate)
image = process_image_data(image)
return mfcc, image, label
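

# A minimal usage sketch (not part of the original pipeline): wrap the dataset in a
# DataLoader for training. The batch size and worker count are illustrative
# assumptions, and default collation assumes the processed mfcc/image tensors have
# consistent shapes across samples.
def make_dataloader(data_dir, batch_size=32, num_workers=4, shuffle=True):
    from torch.utils.data import DataLoader

    dataset = PreprocessedDataset(data_dir)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)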


def load_audio_file(audio_path):
    """Load an audio file with torchaudio, falling back to librosa if decoding fails."""
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")
try:
# Try the default torchaudio loader first
waveform, sample_rate = torchaudio.load(audio_path)
except Exception as e:
print(f"Warning: Could not load {audio_path} with torchaudio: {e}")
# Fall back to librosa (you'll need to install it: pip install librosa)
try:
import librosa
import numpy as np
waveform_np, sample_rate = librosa.load(audio_path, sr=None)
# Convert to torch tensor with shape [1, length] to match torchaudio format
waveform = torch.from_numpy(waveform_np[np.newaxis, :]).float()
print(f"Successfully loaded with librosa: {audio_path}")
except Exception as final_e:
raise RuntimeError(f"Failed to load audio file {audio_path} with all available methods: {final_e}")
return waveform, sample_rate


def load_image_file(image_path):
    """Load an image file as a uint8 CxHxW tensor with torchvision."""
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")
image = torchvision.io.read_image(image_path)
return image


def process_sample(sample_path, save_dir):
    """Preprocess a single sample directory and save the result as a .pt file.

    The first audio file and first image file found under sample_path are used,
    and the directory name is parsed as the sample's numeric label.
    """
    # Recursively search for audio and image files
    audio_files = []
    image_files = []
# Walk through all subdirectories
for root, _, files in os.walk(sample_path):
for file in files:
if file.lower().endswith(('.wav', '.mp3', '.flac')):
audio_files.append(os.path.join(root, file))
elif file.lower().endswith(('.jpg', '.jpeg', '.png')):
image_files.append(os.path.join(root, file))
if not audio_files:
print(f"Warning: No audio file found in {sample_path}. Skipping this sample.")
return
if not image_files:
print(f"Warning: No image file found in {sample_path}. Skipping this sample.")
return
# Use the first found audio and image files
audio_path = audio_files[0]
image_path = image_files[0]
print(f"Processing audio: {audio_path}")
print(f"Processing image: {image_path}")
waveform, sample_rate = load_audio_file(audio_path)
image = load_image_file(image_path)
# Process data
mfcc = process_audio_data(waveform, sample_rate)
processed_image = process_image_data(image)
# Save processed data
save_path = os.path.join(save_dir, f"{os.path.basename(sample_path)}.pt")
torch.save((mfcc, processed_image, float(os.path.basename(sample_path))), save_path)
print(f"Processed and saved: {save_path}")


def process_and_save(data_dir, save_dir):
    """Preprocess every sample directory under data_dir in parallel and save to save_dir."""
    os.makedirs(save_dir, exist_ok=True)

    # Each immediate subdirectory of data_dir is treated as one sample
    sample_paths = [
        os.path.join(data_dir, d)
        for d in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, d))
    ]
if not sample_paths:
print(f"Warning: No sample directories found in {data_dir}")
return
print(f"Found {len(sample_paths)} sample directories to process")
successful = 0
failed = 0
with ThreadPoolExecutor() as executor:
futures = [executor.submit(process_sample, path, save_dir) for path in sample_paths]
for future in futures:
try:
                future.result()  # Block until this sample finishes; re-raises any worker exception
successful += 1
except Exception as e:
failed += 1
print(f"Error processing a sample: {e}")
print(f"Processing complete. Successfully processed: {successful}, Failed: {failed}")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Preprocess the dataset")
parser.add_argument(
"--data_dir",
type=str,
default="cleaned",
help="Path to the cleaned dataset directory",
)
parser.add_argument(
"--save_dir",
type=str,
default="processed",
help="Path to the processed dataset directory",
)
args = parser.parse_args()
print(f"Processing dataset from: {args.data_dir}")
print(f"Saving processed data to: {args.save_dir}")
process_and_save(args.data_dir, args.save_dir)
print("Preprocessing complete")