File size: 5,083 Bytes
fdc673b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import os
import glob
import torch
import torchaudio
import torchvision
from torch.utils.data import Dataset
from concurrent.futures import ThreadPoolExecutor
from preprocess import process_audio_data, process_image_data, resample_rate

class PreprocessedDataset(Dataset):
    """Dataset over pre-serialized (mfcc, image, label) tuples stored as .pt files.

    Each item is re-run through the preprocessing transforms on load so the
    stored tensors stay raw and the transform logic lives in one place.
    """

    def __init__(self, data_dir):
        self.data_dir = data_dir
        # Collect every serialized sample file in the directory (non-recursive).
        collected = []
        for name in os.listdir(data_dir):
            if name.endswith(".pt"):
                collected.append(os.path.join(data_dir, name))
        self.samples = collected

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Each .pt file holds a (mfcc, image, label) tuple.
        mfcc, image, label = torch.load(self.samples[idx])

        # Apply the shared preprocessing transforms to the stored tensors.
        return (
            process_audio_data(mfcc, resample_rate),
            process_image_data(image),
            label,
        )

def load_audio_file(audio_path):
    """Load an audio file, preferring torchaudio with a librosa fallback.

    Args:
        audio_path: Path to the audio file on disk.

    Returns:
        (waveform, sample_rate) — waveform is a float tensor shaped
        [channels, length], matching torchaudio's convention.

    Raises:
        FileNotFoundError: if audio_path does not exist.
        RuntimeError: if every available loader fails to decode the file.
    """
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    try:
        # Try the default torchaudio loader first
        return torchaudio.load(audio_path)
    except Exception as e:
        print(f"Warning: Could not load {audio_path} with torchaudio: {e}")

    # Fall back to librosa (you'll need to install it: pip install librosa)
    try:
        import librosa
        import numpy as np

        waveform_np, sample_rate = librosa.load(audio_path, sr=None)
        # Convert to torch tensor with shape [1, length] to match torchaudio format
        waveform = torch.from_numpy(waveform_np[np.newaxis, :]).float()
        print(f"Successfully loaded with librosa: {audio_path}")
        return waveform, sample_rate
    except Exception as final_e:
        raise RuntimeError(f"Failed to load audio file {audio_path} with all available methods: {final_e}")

def load_image_file(image_path):
    """Read an image from disk as a tensor via torchvision.

    Args:
        image_path: Path to the image file.

    Returns:
        The decoded image tensor from ``torchvision.io.read_image``.

    Raises:
        FileNotFoundError: if image_path does not exist.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")

    return torchvision.io.read_image(image_path)

def process_sample(sample_path, save_dir):
    """Preprocess one sample directory into a saved (mfcc, image, label) tuple.

    Recursively finds one audio file and one image file under ``sample_path``,
    runs the shared preprocessing transforms, and saves the result as
    ``<save_dir>/<basename>.pt``. The directory basename is parsed as a float
    and used as the label, so a non-numeric name raises ValueError (which the
    caller counts as a failure).

    Args:
        sample_path: Directory containing the raw audio/image for one sample.
        save_dir: Directory where the processed .pt file is written.

    Returns:
        None. Skips (with a warning) when no audio or no image is found.
    """
    # Recursively search for audio and image files
    audio_files = []
    image_files = []

    # Walk through all subdirectories
    for root, _, files in os.walk(sample_path):
        for file in files:
            if file.lower().endswith(('.wav', '.mp3', '.flac')):
                audio_files.append(os.path.join(root, file))
            elif file.lower().endswith(('.jpg', '.jpeg', '.png')):
                image_files.append(os.path.join(root, file))

    if not audio_files:
        print(f"Warning: No audio file found in {sample_path}. Skipping this sample.")
        return

    if not image_files:
        print(f"Warning: No image file found in {sample_path}. Skipping this sample.")
        return

    # Sort before picking so the chosen file is deterministic: os.walk yields
    # files in filesystem-dependent order, so "first found" could differ
    # between platforms and runs, silently changing the saved tensors.
    audio_path = min(audio_files)
    image_path = min(image_files)

    print(f"Processing audio: {audio_path}")
    print(f"Processing image: {image_path}")

    waveform, sample_rate = load_audio_file(audio_path)
    image = load_image_file(image_path)

    # Process data
    mfcc = process_audio_data(waveform, sample_rate)
    processed_image = process_image_data(image)

    # Save processed data
    save_path = os.path.join(save_dir, f"{os.path.basename(sample_path)}.pt")
    torch.save((mfcc, processed_image, float(os.path.basename(sample_path))), save_path)
    print(f"Processed and saved: {save_path}")

def process_and_save(data_dir, save_dir):
    """Preprocess every sample directory under ``data_dir`` concurrently.

    Each immediate subdirectory of ``data_dir`` is treated as one sample and
    handed to ``process_sample`` on a thread pool; results are written to
    ``save_dir`` (created if missing). Per-sample failures are counted and
    reported rather than aborting the whole run.

    Args:
        data_dir: Root directory containing one subdirectory per sample.
        save_dir: Output directory for the processed .pt files.
    """
    os.makedirs(save_dir, exist_ok=True)

    # One sample per immediate subdirectory of data_dir.
    sample_paths = []
    for entry in os.listdir(data_dir):
        candidate = os.path.join(data_dir, entry)
        if os.path.isdir(candidate):
            sample_paths.append(candidate)

    if not sample_paths:
        print(f"Warning: No sample directories found in {data_dir}")
        return

    print(f"Found {len(sample_paths)} sample directories to process")

    ok_count = 0
    err_count = 0

    with ThreadPoolExecutor() as pool:
        pending = [pool.submit(process_sample, path, save_dir) for path in sample_paths]
        # Drain every future so the pool finishes before we report totals.
        for job in pending:
            try:
                job.result()
                ok_count += 1
            except Exception as e:
                err_count += 1
                print(f"Error processing a sample: {e}")

    print(f"Processing complete. Successfully processed: {ok_count}, Failed: {err_count}")

if __name__ == "__main__":
    import argparse

    # Command-line entry point: preprocess a cleaned dataset into .pt tensors.
    parser = argparse.ArgumentParser(description="Preprocess the dataset")
    parser.add_argument(
        "--data_dir",
        type=str,
        default="cleaned",
        help="Path to the cleaned dataset directory",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="processed",
        help="Path to the processed dataset directory",
    )
    args = parser.parse_args()

    print(f"Processing dataset from: {args.data_dir}")
    print(f"Saving processed data to: {args.save_dir}")
    process_and_save(args.data_dir, args.save_dir)
    print("Preprocessing complete")