Xalphinions committed
Commit fdc673b · verified · 1 parent: 8711293

Upload folder using huggingface_hub

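The commit message indicates the folder was pushed with huggingface_hub. A minimal sketch of that workflow, assuming the standard upload_folder API; the folder path and repo id below are illustrative placeholders, not values recorded in this commit:

from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path="./watermelon",        # local project folder (placeholder)
    repo_id="Xalphinions/watermelon",  # target Space repo (assumed)
    repo_type="space",
    commit_message="Upload folder using huggingface_hub",
)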
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ temp/temp_image.jpg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,6 @@
  ---
- title: Watermelon
- emoji: 👁
- colorFrom: pink
- colorTo: red
- sdk: gradio
- sdk_version: 5.23.3
+ title: watermelon
  app_file: app.py
- pinned: false
+ sdk: gradio
+ sdk_version: 4.44.1
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
__pycache__/infer.cpython-39.pyc ADDED
Binary file (1.48 kB)
 
__pycache__/preprocess.cpython-39.pyc ADDED
Binary file (1.27 kB)
 
__pycache__/train.cpython-39.pyc ADDED
Binary file (7.13 kB)
 
__pycache__/train_2.cpython-39.pyc ADDED
Binary file (10.7 kB)
 
app.py ADDED
@@ -0,0 +1,284 @@
1
+ import torch, torchaudio, torchvision
2
+ import os
3
+ import gradio as gr
4
+ import numpy as np
5
+
6
+ from preprocess import process_audio_data, process_image_data
7
+ from train import WatermelonModel
8
+ from infer import infer
9
+
10
+ def load_model(model_path):
11
+ global device
12
+ device = torch.device(
13
+ "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
14
+ )
15
+ print(f"\033[92mINFO\033[0m: Using device: {device}")
16
+
17
+ # Check if the file exists
18
+ if not os.path.exists(model_path):
19
+ raise FileNotFoundError(f"Model file not found at {model_path}")
20
+
21
+ # Check if the file is empty or very small
22
+ file_size = os.path.getsize(model_path)
23
+ if file_size < 1000: # Less than 1KB is suspiciously small for a model
24
+ print(f"\033[93mWARNING\033[0m: Model file size is only {file_size} bytes, which is suspiciously small")
25
+
26
+ try:
27
+ model = WatermelonModel().to(device)
28
+ model.load_state_dict(torch.load(model_path, map_location=device))
29
+ model.eval()
30
+ print(f"\033[92mINFO\033[0m: Loaded model from {model_path}")
31
+ return model
32
+ except RuntimeError as e:
33
+ if "failed finding central directory" in str(e):
34
+ print(f"\033[91mERROR\033[0m: The model file at {model_path} appears to be corrupted.")
35
+ print("This can happen if:")
36
+ print(" 1. The model saving process was interrupted")
37
+ print(" 2. The file was not properly downloaded")
38
+ print(" 3. The path points to a file that is not a valid PyTorch model")
39
+ print(f"File size: {file_size} bytes")
40
+ raise
41
+
42
+ if __name__ == "__main__":
43
+ import argparse
44
+
45
+ parser = argparse.ArgumentParser(description="Watermelon sweetness predictor")
46
+ parser.add_argument("--model_path", type=str, required=True, help="Path to the trained model")
47
+ args = parser.parse_args()
48
+
49
+ model = load_model(args.model_path)
50
+
51
+ def predict(audio, image):
52
+ try:
53
+ # Debug audio input
54
+ print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
55
+ print(f"\033[92mDEBUG\033[0m: Audio input value: {audio}")
56
+
57
+ # Handle different formats of audio input from Gradio
58
+ if audio is None:
59
+ return "Error: No audio provided. Please upload or record audio."
60
+
61
+ if isinstance(audio, tuple) and len(audio) >= 2:
62
+ sr, audio_data = audio[0], audio[-1]
63
+ print(f"\033[92mDEBUG\033[0m: Audio format: sr={sr}, audio_data shape={audio_data.shape if hasattr(audio_data, 'shape') else 'no shape'}")
64
+ elif isinstance(audio, tuple) and len(audio) == 1:
65
+ # Handle single element tuple
66
+ audio_data = audio[0]
67
+ sr = 44100 # Assume default sample rate
68
+ print(f"\033[92mDEBUG\033[0m: Single element audio tuple, using default sr={sr}")
69
+ elif isinstance(audio, np.ndarray):
70
+ # Handle direct numpy array
71
+ audio_data = audio
72
+ sr = 44100 # Assume default sample rate
73
+ print(f"\033[92mDEBUG\033[0m: Audio is numpy array, using default sr={sr}")
74
+ else:
75
+ return f"Error: Unexpected audio format: {type(audio)}"
76
+
77
+ # Ensure audio_data is correctly shaped
78
+ if isinstance(audio_data, np.ndarray):
79
+ # Make sure we have a 2D array
80
+ if len(audio_data.shape) == 1:
81
+ audio_data = np.expand_dims(audio_data, axis=0)
82
+ print(f"\033[92mDEBUG\033[0m: Reshaped 1D audio to 2D: {audio_data.shape}")
83
+
84
+ # If channels are the second dimension, transpose
85
+ if len(audio_data.shape) == 2 and audio_data.shape[0] > audio_data.shape[1]:
86
+ audio_data = np.transpose(audio_data)
87
+ print(f"\033[92mDEBUG\033[0m: Transposed audio shape to: {audio_data.shape}")
88
+
89
+ # Convert to tensor
90
+ audio_tensor = torch.tensor(audio_data).float()
91
+ print(f"\033[92mDEBUG\033[0m: Audio tensor shape: {audio_tensor.shape}")
92
+
93
+ # Process audio data and handle None case
94
+ mfcc = process_audio_data(audio_tensor, sr)
95
+ if mfcc is None:
96
+ return "Error: Failed to process audio data. Make sure your audio contains a clear tapping sound."
97
+
98
+ mfcc = mfcc.to(device)
99
+ print(f"\033[92mDEBUG\033[0m: MFCC shape: {mfcc.shape}")
100
+
101
+ # Debug image input
102
+ print(f"\033[92mDEBUG\033[0m: Image input type: {type(image)}")
103
+ print(f"\033[92mDEBUG\033[0m: Image shape: {image.shape if hasattr(image, 'shape') else 'No shape'}")
104
+
105
+ # Process image data and handle None case
106
+ if image is None:
107
+ return "Error: No image provided. Please upload an image."
108
+
109
+ # Handle different image formats
110
+ if isinstance(image, np.ndarray):
111
+ # Check if image is properly formatted (H, W, C) with 3 channels
112
+ if len(image.shape) == 3 and image.shape[2] == 3:
113
+ # Convert to tensor with shape (C, H, W) as expected by PyTorch
114
+ img = torch.tensor(image).float().permute(2, 0, 1)
115
+ print(f"\033[92mDEBUG\033[0m: Converted image to tensor with shape: {img.shape}")
116
+ elif len(image.shape) == 2:
117
+ # Grayscale image, expand to 3 channels
118
+ img = torch.tensor(image).float().unsqueeze(0).repeat(3, 1, 1)
119
+ print(f"\033[92mDEBUG\033[0m: Converted grayscale image to RGB tensor with shape: {img.shape}")
120
+ else:
121
+ return f"Error: Unexpected image shape: {image.shape}. Expected RGB or grayscale image."
122
+ else:
123
+ return f"Error: Unexpected image format: {type(image)}. Expected numpy array."
124
+
125
+ # Scale pixel values to [0, 1] if needed
126
+ if img.max() > 1.0:
127
+ img = img / 255.0
128
+ print(f"\033[92mDEBUG\033[0m: Scaled image pixel values to range [0, 1]")
129
+
130
+ # Get image dimensions and check if they're reasonable
131
+ print(f"\033[92mDEBUG\033[0m: Final image tensor shape before processing: {img.shape}")
132
+
133
+ # Process image
134
+ try:
135
+ img_processed = process_image_data(img)
136
+ if img_processed is None:
137
+ return "Error: Failed to process image data. Make sure your image clearly shows a watermelon."
138
+
139
+ img_processed = img_processed.to(device)
140
+ print(f"\033[92mDEBUG\033[0m: Processed image shape: {img_processed.shape}")
141
+ except Exception as e:
142
+ print(f"\033[91mERROR\033[0m: Image processing error: {str(e)}")
143
+ return f"Error in image processing: {str(e)}"
144
+
145
+ # Run inference
146
+ try:
147
+ # Based on the error, it seems infer() expects file paths, not tensors
148
+ # Let's create temporary files for the processed data
149
+ temp_dir = os.path.join(os.getcwd(), "temp")
150
+ os.makedirs(temp_dir, exist_ok=True)
151
+
152
+ # Save the audio to a temporary file if infer expects a file path
153
+ temp_audio_path = os.path.join(temp_dir, "temp_audio.wav")
154
+ if not isinstance(audio, str) and isinstance(audio, tuple) and len(audio) >= 2:
155
+ # If we have the original audio data and sample rate
156
+ audio_array = audio[-1]
157
+ sr = audio[0]
158
+
159
+ # Check if the audio array is valid
160
+ if audio_array.size == 0:
161
+ return "Error: Audio data is empty. Please record a longer audio clip."
162
+
163
+ # Get the duration of the audio
164
+ duration = audio_array.shape[-1] / sr
165
+ print(f"\033[92mDEBUG\033[0m: Audio duration: {duration:.2f} seconds")
166
+
167
+ # Check if we have at least 1 second of audio - but don't reject, just pad if needed
168
+ min_duration = 1.0 # minimum 1 second of audio
169
+ if duration < min_duration:
170
+ print(f"\033[93mWARNING\033[0m: Audio is shorter than {min_duration} seconds. Padding will be applied.")
171
+ # Calculate samples needed to reach minimum duration
172
+ samples_needed = int(min_duration * sr) - audio_array.shape[-1]
173
+ # Pad with zeros
174
+ padding = np.zeros((audio_array.shape[0], samples_needed), dtype=audio_array.dtype)
175
+ audio_array = np.concatenate([audio_array, padding], axis=1)
176
+ print(f"\033[92mDEBUG\033[0m: Padded audio to shape: {audio_array.shape}")
177
+
178
+ # Make sure audio has 2 dimensions
179
+ if len(audio_array.shape) == 1:
180
+ audio_array = np.expand_dims(audio_array, axis=0)
181
+
182
+ print(f"\033[92mDEBUG\033[0m: Audio array shape before saving: {audio_array.shape}, sr: {sr}")
183
+
184
+ # Make sure it's in the right format for torchaudio.save
185
+ audio_tensor = torch.tensor(audio_array).float()
186
+ if audio_tensor.dim() == 1:
187
+ audio_tensor = audio_tensor.unsqueeze(0)
188
+
189
+ torchaudio.save(temp_audio_path, audio_tensor, sr)
190
+ print(f"\033[92mDEBUG\033[0m: Saved temporary audio file to {temp_audio_path}")
191
+
192
+ # Let's also process the audio here to verify it works
193
+ test_mfcc = process_audio_data(audio_tensor, sr)
194
+ if test_mfcc is None:
195
+ return "Error: Unable to process the audio. Please try recording a different audio sample."
196
+ else:
197
+ print(f"\033[92mDEBUG\033[0m: Audio pre-check passed. MFCC shape: {test_mfcc.shape}")
198
+
199
+ audio_path = temp_audio_path
200
+ else:
201
+ # If we don't have a valid path, return an error
202
+ return "Error: Cannot process audio for inference. Invalid audio format."
203
+
204
+ # Save the image to a temporary file if infer expects a file path
205
+ temp_image_path = os.path.join(temp_dir, "temp_image.jpg")
206
+ if isinstance(image, np.ndarray):
207
+ import cv2
208
+ cv2.imwrite(temp_image_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
209
+ print(f"\033[92mDEBUG\033[0m: Saved temporary image file to {temp_image_path}")
210
+ image_path = temp_image_path
211
+ else:
212
+ # If we don't have a valid image, return an error
213
+ return "Error: Cannot process image for inference. Invalid image format."
214
+
215
+ # Create a modified version of infer that handles None returns
216
+ def safe_infer(audio_path, image_path, model, device):
217
+ try:
218
+ return infer(audio_path, image_path, model, device)
219
+ except Exception as e:
220
+ print(f"\033[91mERROR\033[0m: Error in infer function: {str(e)}")
221
+ # Try a more direct approach
222
+ try:
223
+ # Load audio and process
224
+ audio, sr = torchaudio.load(audio_path)
225
+ mfcc = process_audio_data(audio, sr)
226
+ if mfcc is None:
227
+ raise ValueError("Audio processing failed - MFCC is None")
228
+ mfcc = mfcc.to(device)
229
+
230
+ # Load image and process
231
+ image = cv2.imread(image_path)
232
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
233
+ image_tensor = torch.tensor(image).float().permute(2, 0, 1) / 255.0
234
+ img_processed = process_image_data(image_tensor)
235
+ if img_processed is None:
236
+ raise ValueError("Image processing failed - processed image is None")
237
+ img_processed = img_processed.to(device)
238
+
239
+ # Run model inference
240
+ with torch.no_grad():
241
+ prediction = model(mfcc, img_processed)
242
+ return prediction
243
+ except Exception as e2:
244
+ print(f"\033[91mERROR\033[0m: Fallback inference also failed: {str(e2)}")
245
+ raise
246
+
247
+ # Call our safer version
248
+ print(f"\033[92mDEBUG\033[0m: Calling safe_infer with audio_path={audio_path}, image_path={image_path}")
249
+ sweetness = safe_infer(audio_path, image_path, model, device)
250
+ if sweetness is None:
251
+ return "Error: The model was unable to make a prediction. Please try with different inputs."
252
+
253
+ print(f"\033[92mDEBUG\033[0m: Inference result: {sweetness.item()}")
254
+ return f"Predicted Sweetness: {sweetness.item():.2f}/10"
255
+ except Exception as e:
256
+ import traceback
257
+ print(f"\033[91mERROR\033[0m: Inference failed: {str(e)}")
258
+ print(f"\033[91mTraceback\033[0m: {traceback.format_exc()}")
259
+ return f"Error during inference: {str(e)}"
260
+
261
+ except Exception as e:
262
+ import traceback
263
+ print(f"\033[91mERROR\033[0m: Prediction failed: {str(e)}")
264
+ print(f"\033[91mTraceback\033[0m: {traceback.format_exc()}")
265
+ return f"Error processing input: {str(e)}"
266
+
267
+ audio_input = gr.Audio(label="Upload or Record Audio")
268
+ image_input = gr.Image(label="Upload or Capture Image")
269
+ output = gr.Textbox(label="Predicted Sweetness")
270
+
271
+ interface = gr.Interface(
272
+ fn=predict,
273
+ inputs=[audio_input, image_input],
274
+ outputs=output,
275
+ title="Watermelon Sweetness Predictor",
276
+ description="Upload an audio file and an image to predict the sweetness of a watermelon."
277
+ )
278
+
279
+ try:
280
+ interface.launch(share=True) # Enable sharing to avoid localhost access issues
281
+ except Exception as e:
282
+ print(f"\033[91mERROR\033[0m: Failed to launch interface: {e}")
283
+ print("\033[93mTIP\033[0m: If you're running in a remote environment or container, try setting additional parameters:")
284
+ print(" interface.launch(server_name='0.0.0.0', share=True)")
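As a self-contained illustration of the temp-file round trip that predict() performs before calling infer() (and the reason temp/temp_audio.wav and temp/temp_image.jpg appear in this commit), here is a sketch with synthetic data; the sample rate and image size are assumptions, not values from a real session:

import os
import numpy as np
import torch, torchaudio
import cv2

temp_dir = os.path.join(os.getcwd(), "temp")
os.makedirs(temp_dir, exist_ok=True)

# One second of synthetic mono audio at 44.1 kHz, shaped (channels, samples) for torchaudio.save
sr = 44100
torchaudio.save(os.path.join(temp_dir, "temp_audio.wav"), torch.randn(1, sr), sr)

# A synthetic RGB frame; cv2.imwrite expects BGR channel order, hence the conversion
rgb = (np.random.rand(1080, 1080, 3) * 255).astype(np.uint8)
cv2.imwrite(os.path.join(temp_dir, "temp_image.jpg"), cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))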
infer.py ADDED
@@ -0,0 +1,52 @@
1
+ import torch, torchaudio, torchvision
2
+ import argparse
3
+ from preprocess import process_audio_data, process_image_data
4
+ from train import WatermelonModel
5
+
6
+
7
+ def infer(audio, image, model, device):
8
+ # Load and preprocess the input data
9
+ audio, sr = torchaudio.load(audio)
10
+ mfcc = process_audio_data(audio, sr)
11
+ img = process_image_data(torchvision.io.read_image(image).float())  # image is a file path here
12
+ if mfcc is None or img is None:
13
+ return None
14
+ # Add a batch dimension and move the tensors to the target device
+ mfcc, img = mfcc.unsqueeze(0).to(device), img.unsqueeze(0).to(device)
+
15
+ # Run inference
16
+ with torch.no_grad():
17
+ predicted_sweetness = model(mfcc, img).item()
18
+
19
+ return predicted_sweetness
20
+
21
+
22
+ if __name__ == "__main__":
23
+ parser = argparse.ArgumentParser(description="Run Watermelon Sweetness Prediction")
24
+ parser.add_argument(
25
+ "--model_path", type=str, required=True, help="Path to the saved model file"
26
+ )
27
+ parser.add_argument(
28
+ "--audio_path", type=str, required=True, help="Path to audio file"
29
+ )
30
+ parser.add_argument(
31
+ "--image_path", type=str, required=True, help="Path to image file"
32
+ )
33
+ args = parser.parse_args()
34
+
35
+ # Initialize the model and device
36
+ print(f"\033[92mINFO\033[0m: PyTorch version: {torch.__version__}")
37
+ device = torch.device(
38
+ "cuda"
39
+ if torch.cuda.is_available()
40
+ else "mps" if torch.backends.mps.is_available() else "cpu"
41
+ )
42
+ print(f"\033[92mINFO\033[0m: Using device: {device}")
43
+ model = WatermelonModel().to(device)
44
+ model.load_state_dict(torch.load(args.model_path, map_location=device))
45
+
46
+ # Example paths to audio and image files
47
+ audio_path = args.audio_path
48
+ image_path = args.image_path
49
+
50
+ # Run inference
51
+ sweetness = infer(audio_path, image_path, model, device)
52
+ print(f"Predicted sweetness: {sweetness}")
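A minimal sketch of calling infer() programmatically instead of through the CLI entry point, assuming a trained checkpoint and input files at placeholder paths:

import torch
from train import WatermelonModel
from infer import infer

device = torch.device("cpu")
model = WatermelonModel().to(device)
model.load_state_dict(torch.load("models/best.pt", map_location=device))  # placeholder checkpoint path
model.eval()

sweetness = infer("tap.wav", "melon.jpg", model, device)  # placeholder audio/image paths
print(f"Predicted sweetness: {sweetness}")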
preprocess.py ADDED
@@ -0,0 +1,45 @@
1
+ import torch
2
+ import torchaudio
3
+ import torchvision
4
+
5
+ resample_rate = 16000
6
+
7
+ def process_audio_data(waveform, sample_rate):
8
+ try:
9
+ waveform = waveform[0]  # use the left channel only
10
+ waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=resample_rate)(waveform)
11
+
12
+ if waveform.size(0) < 3 * resample_rate:
13
+ waveform = torch.nn.functional.pad(waveform, (0, 3 * resample_rate - waveform.size(0)))
14
+ else:
15
+ waveform = waveform[: 3 * resample_rate]
16
+
17
+ mfcc = torchaudio.transforms.MFCC(
18
+ sample_rate=resample_rate,
19
+ n_mfcc=13,
20
+ melkwargs={
21
+ "n_fft": 256,
22
+ "win_length": 256,
23
+ "hop_length": 128,
24
+ "n_mels": 40,
25
+ }
26
+ )(waveform)
27
+
28
+ return mfcc
29
+ except Exception as e:
30
+ print(f"ERR!: Error in audio processing: {e}")
31
+ return None
32
+
33
+ def process_image_data(image):
34
+ try:
35
+ image = torchvision.transforms.Resize((1080, 1080))(image)
36
+ image = image / 255.0
37
+ image = torchvision.transforms.Normalize(
38
+ mean=[0.485, 0.456, 0.406],
39
+ std=[0.229, 0.224, 0.225]
40
+ )(image)
41
+
42
+ return image
43
+ except Exception as e:
44
+ print(f"ERR!: Error in image processing: {e}")
45
+ return None
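A small sketch of the tensor shapes these helpers produce, assuming a mono 16 kHz waveform and a float image tensor in the 0-255 range:

import torch
from preprocess import process_audio_data, process_image_data, resample_rate

# Two seconds of mono audio; process_audio_data pads/crops to 3 s before computing the MFCC
waveform = torch.randn(1, 2 * resample_rate)
mfcc = process_audio_data(waveform, resample_rate)
print(mfcc.shape)  # torch.Size([13, 376]): 13 coefficients over 376 hop frames of 3 s at 16 kHz

# Any RGB image tensor with 0-255 values; the helper resizes to 1080x1080, scales, and normalizes
image = torch.rand(3, 720, 960) * 255
print(process_image_data(image).shape)  # torch.Size([3, 1080, 1080])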
preprocess_file.py ADDED
@@ -0,0 +1,149 @@
1
+ import os
2
+ import glob
3
+ import torch
4
+ import torchaudio
5
+ import torchvision
6
+ from torch.utils.data import Dataset
7
+ from concurrent.futures import ThreadPoolExecutor
8
+ from preprocess import process_audio_data, process_image_data, resample_rate
9
+
10
+ class PreprocessedDataset(Dataset):
11
+ def __init__(self, data_dir):
12
+ self.data_dir = data_dir
13
+ self.samples = [
14
+ os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".pt")
15
+ ]
16
+
17
+ def __len__(self):
18
+ return len(self.samples)
19
+
20
+ def __getitem__(self, idx):
21
+ sample_path = self.samples[idx]
22
+ mfcc, image, label = torch.load(sample_path)
23
+
24
+ # Process data
25
+ mfcc = process_audio_data(mfcc, resample_rate)
26
+ image = process_image_data(image)
27
+
28
+ return mfcc, image, label
29
+
30
+ def load_audio_file(audio_path):
31
+ if not os.path.exists(audio_path):
32
+ raise FileNotFoundError(f"Audio file not found: {audio_path}")
33
+
34
+ try:
35
+ # Try the default torchaudio loader first
36
+ waveform, sample_rate = torchaudio.load(audio_path)
37
+ except Exception as e:
38
+ print(f"Warning: Could not load {audio_path} with torchaudio: {e}")
39
+
40
+ # Fall back to librosa (you'll need to install it: pip install librosa)
41
+ try:
42
+ import librosa
43
+ import numpy as np
44
+
45
+ waveform_np, sample_rate = librosa.load(audio_path, sr=None)
46
+ # Convert to torch tensor with shape [1, length] to match torchaudio format
47
+ waveform = torch.from_numpy(waveform_np[np.newaxis, :]).float()
48
+ print(f"Successfully loaded with librosa: {audio_path}")
49
+ except Exception as final_e:
50
+ raise RuntimeError(f"Failed to load audio file {audio_path} with all available methods: {final_e}")
51
+
52
+ return waveform, sample_rate
53
+
54
+ def load_image_file(image_path):
55
+ if not os.path.exists(image_path):
56
+ raise FileNotFoundError(f"Image file not found: {image_path}")
57
+
58
+ image = torchvision.io.read_image(image_path)
59
+ return image
60
+
61
+ def process_sample(sample_path, save_dir):
62
+ # Recursively search for audio and image files
63
+ audio_files = []
64
+ image_files = []
65
+
66
+ # Walk through all subdirectories
67
+ for root, _, files in os.walk(sample_path):
68
+ for file in files:
69
+ if file.lower().endswith(('.wav', '.mp3', '.flac')):
70
+ audio_files.append(os.path.join(root, file))
71
+ elif file.lower().endswith(('.jpg', '.jpeg', '.png')):
72
+ image_files.append(os.path.join(root, file))
73
+
74
+ if not audio_files:
75
+ print(f"Warning: No audio file found in {sample_path}. Skipping this sample.")
76
+ return
77
+
78
+ if not image_files:
79
+ print(f"Warning: No image file found in {sample_path}. Skipping this sample.")
80
+ return
81
+
82
+ # Use the first found audio and image files
83
+ audio_path = audio_files[0]
84
+ image_path = image_files[0]
85
+
86
+ print(f"Processing audio: {audio_path}")
87
+ print(f"Processing image: {image_path}")
88
+
89
+ waveform, sample_rate = load_audio_file(audio_path)
90
+ image = load_image_file(image_path)
91
+
92
+ # Process data
93
+ mfcc = process_audio_data(waveform, sample_rate)
94
+ processed_image = process_image_data(image)
95
+
96
+ # Save processed data
97
+ save_path = os.path.join(save_dir, f"{os.path.basename(sample_path)}.pt")
98
+ torch.save((mfcc, processed_image, float(os.path.basename(sample_path))), save_path)
99
+ print(f"Processed and saved: {save_path}")
100
+
101
+ def process_and_save(data_dir, save_dir):
102
+ os.makedirs(save_dir, exist_ok=True)
103
+ sample_paths = [os.path.join(data_dir, d) for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))]
104
+
105
+ if not sample_paths:
106
+ print(f"Warning: No sample directories found in {data_dir}")
107
+ return
108
+
109
+ print(f"Found {len(sample_paths)} sample directories to process")
110
+
111
+ successful = 0
112
+ failed = 0
113
+
114
+ with ThreadPoolExecutor() as executor:
115
+ futures = [executor.submit(process_sample, path, save_dir) for path in sample_paths]
116
+ for future in futures:
117
+ try:
118
+ future.result() # Wait for all threads to complete
119
+ successful += 1
120
+ except Exception as e:
121
+ failed += 1
122
+ print(f"Error processing a sample: {e}")
123
+
124
+ print(f"Processing complete. Successfully processed: {successful}, Failed: {failed}")
125
+
126
+ if __name__ == "__main__":
127
+ import argparse
128
+
129
+ parser = argparse.ArgumentParser(description="Preprocess the dataset")
130
+ parser.add_argument(
131
+ "--data_dir",
132
+ type=str,
133
+ default="cleaned",
134
+ help="Path to the cleaned dataset directory",
135
+ )
136
+ parser.add_argument(
137
+ "--save_dir",
138
+ type=str,
139
+ default="processed",
140
+ help="Path to the processed dataset directory",
141
+ )
142
+ args = parser.parse_args()
143
+
144
+ print(f"Processing dataset from: {args.data_dir}")
145
+ print(f"Saving processed data to: {args.save_dir}")
146
+
147
+ process_and_save(args.data_dir, args.save_dir)
148
+
149
+ print("Preprocessing complete")
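Note that process_sample() derives the regression label from the sample directory name via float(os.path.basename(sample_path)), so each sample folder is expected to be named with its sweetness value. A sketch of the assumed layout and of reading one processed sample back (the "9.4" value is hypothetical):

# Assumed input layout:
#   cleaned/9.4/recording.wav
#   cleaned/9.4/photo.jpg
# which process_and_save() turns into processed/9.4.pt holding (mfcc, image, label).

import torch

mfcc, image, label = torch.load("processed/9.4.pt")
print(mfcc.shape, image.shape, label)  # e.g. torch.Size([13, 376]) torch.Size([3, 1080, 1080]) 9.4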
temp/temp_audio.wav ADDED
Binary file (58 Bytes)
 
temp/temp_image.jpg ADDED

Git LFS Details

  • SHA256: 4af3d138bad5d27184910f9d38bac40565a57a7bb2f5efc12a7d7e28aa8f126a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.02 MB
train.py ADDED
@@ -0,0 +1,283 @@
1
+ import os
2
+ import time
3
+
4
+ import torch, torchaudio, torchvision
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from torch.utils.tensorboard import SummaryWriter
7
+ import numpy as np
8
+
9
+ # Print library version information
10
+ print(f"\033[92mINFO\033[0m: PyTorch version: {torch.__version__}")
11
+ print(f"\033[92mINFO\033[0m: Torchaudio version: {torchaudio.__version__}")
12
+ print(f"\033[92mINFO\033[0m: Torchvision version: {torchvision.__version__}")
13
+
14
+ # Device selection
15
+ device = torch.device(
16
+ "cuda"
17
+ if torch.cuda.is_available()
18
+ else "mps" if torch.backends.mps.is_available() else "cpu"
19
+ )
20
+ print(f"\033[92mINFO\033[0m: Using device: {device}")
21
+
22
+ # Hyperparameters
23
+ batch_size = 1
24
+ epochs = 20
25
+
26
+ # Directory for model checkpoints
27
+ os.makedirs("./models/", exist_ok=True)
28
+
29
+
30
+ class PreprocessedDataset(Dataset):
31
+ def __init__(self, data_dir):
32
+ self.data_dir = data_dir
33
+ self.samples = [
34
+ os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".pt")
35
+ ]
36
+
37
+ def __len__(self):
38
+ return len(self.samples)
39
+
40
+ def __getitem__(self, idx):
41
+ sample_path = self.samples[idx]
42
+ mfcc, image, label = torch.load(sample_path)
43
+ return mfcc.float(), image.float(), label
44
+
45
+
46
+ class WatermelonModel(torch.nn.Module):
47
+ def __init__(self):
48
+ super(WatermelonModel, self).__init__()
49
+
50
+ # LSTM for audio features
51
+ self.lstm = torch.nn.LSTM(
52
+ input_size=376, hidden_size=64, num_layers=2, batch_first=True
53
+ )
54
+ self.lstm_fc = torch.nn.Linear(
55
+ 64, 128
56
+ ) # Convert LSTM output to 128-dim for merging
57
+
58
+ # ResNet50 for image features
59
+ self.resnet = torchvision.models.resnet50(pretrained=True)
60
+ self.resnet.fc = torch.nn.Linear(
61
+ self.resnet.fc.in_features, 128
62
+ ) # Convert ResNet output to 128-dim for merging
63
+
64
+ # Fully connected layers for final prediction
65
+ self.fc1 = torch.nn.Linear(256, 64)
66
+ self.fc2 = torch.nn.Linear(64, 1)
67
+ self.relu = torch.nn.ReLU()
68
+
69
+ def forward(self, mfcc, image):
70
+ # LSTM branch
71
+ lstm_output, _ = self.lstm(mfcc)
72
+ lstm_output = lstm_output[:, -1, :] # Use the output of the last time step
73
+ lstm_output = self.lstm_fc(lstm_output)
74
+
75
+ # ResNet branch
76
+ resnet_output = self.resnet(image)
77
+
78
+ # Concatenate LSTM and ResNet outputs
79
+ merged = torch.cat((lstm_output, resnet_output), dim=1)
80
+
81
+ # Fully connected layers
82
+ output = self.relu(self.fc1(merged))
83
+ output = self.fc2(output)
84
+
85
+ return output
86
+
87
+
88
+ def evaluate_model(model, test_loader, criterion):
89
+ model.eval()
90
+ test_loss = 0.0
91
+ mae_sum = 0.0
92
+ all_predictions = []
93
+ all_labels = []
94
+
95
+ # For debugging
96
+ debug_samples = []
97
+
98
+ with torch.no_grad():
99
+ for mfcc, image, label in test_loader:
100
+ mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
101
+ output = model(mfcc, image)
102
+ label = label.view(-1, 1).float()
103
+
104
+ # Store debug samples
105
+ if len(debug_samples) < 5:
106
+ debug_samples.append((output.item(), label.item()))
107
+
108
+ # Calculate MSE loss
109
+ loss = criterion(output, label)
110
+ test_loss += loss.item()
111
+
112
+ # Calculate MAE
113
+ mae = torch.abs(output - label).mean()
114
+ mae_sum += mae.item()
115
+
116
+ # Store predictions and labels for additional analysis
117
+ all_predictions.extend(output.cpu().numpy())
118
+ all_labels.extend(label.cpu().numpy())
119
+
120
+ avg_loss = test_loss / len(test_loader)
121
+ avg_mae = mae_sum / len(test_loader)
122
+
123
+ # Convert to numpy arrays for easier analysis
124
+ all_predictions = np.array(all_predictions).flatten()
125
+ all_labels = np.array(all_labels).flatten()
126
+
127
+ # Print debug samples
128
+ print("\nDEBUG SAMPLES (Prediction, Label):")
129
+ for i, (pred, label) in enumerate(debug_samples):
130
+ print(f"Sample {i+1}: Prediction = {pred:.4f}, Label = {label:.4f}, Difference = {abs(pred-label):.4f}")
131
+
132
+ return avg_loss, avg_mae, all_predictions, all_labels
133
+
134
+
135
+ def train_model():
136
+ # Load the dataset
137
+ data_dir = "./processed/"
138
+ dataset = PreprocessedDataset(data_dir)
139
+ n_samples = len(dataset)
140
+
141
+ # Check label range
142
+ all_labels = []
143
+ for i in range(min(10, len(dataset))):
144
+ _, _, label = dataset[i]
145
+ all_labels.append(label)
146
+
147
+ print("\nLABEL RANGE CHECK:")
148
+ print(f"Sample labels: {all_labels}")
149
+ print(f"Min label: {min(all_labels)}, Max label: {max(all_labels)}")
150
+
151
+ train_size = int(0.7 * n_samples)
152
+ val_size = int(0.2 * n_samples)
153
+ test_size = n_samples - train_size - val_size
154
+
155
+ train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
156
+ dataset, [train_size, val_size, test_size]
157
+ )
158
+
159
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
160
+ val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
161
+ test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
162
+
163
+ model = WatermelonModel().to(device)
164
+
165
+ # Loss function and optimizer
166
+ criterion = torch.nn.MSELoss()
167
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
168
+
169
+ # TensorBoard
170
+ writer = SummaryWriter("runs/")
171
+ global_step = 0
172
+
173
+ print(f"\033[92mINFO\033[0m: Training model for {epochs} epochs")
174
+ print(f"\033[92mINFO\033[0m: Training samples: {len(train_dataset)}")
175
+ print(f"\033[92mINFO\033[0m: Validation samples: {len(val_dataset)}")
176
+ print(f"\033[92mINFO\033[0m: Test samples: {len(test_dataset)}")
177
+ print(f"\033[92mINFO\033[0m: Batch size: {batch_size}")
178
+
179
+ best_val_loss = float('inf')
180
+ best_model_path = None
181
+
182
+ # Training loop
183
+ for epoch in range(epochs):
184
+ print(f"\033[92mINFO\033[0m: Training epoch ({epoch+1}/{epochs})")
185
+
186
+ model.train()
187
+ running_loss = 0.0
188
+ try:
189
+ for mfcc, image, label in train_loader:
190
+ mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
191
+
192
+ optimizer.zero_grad()
193
+ output = model(mfcc, image)
194
+ label = label.view(-1, 1).float()
195
+ loss = criterion(output, label)
196
+ loss.backward()
197
+ optimizer.step()
198
+
199
+ running_loss += loss.item()
200
+ writer.add_scalar("Training Loss", loss.item(), global_step)
201
+ global_step += 1
202
+ except Exception as e:
203
+ print(f"\033[91mERR!\033[0m: {e}")
204
+
205
+ # Validation phase
206
+ model.eval()
207
+ val_loss = 0.0
208
+ with torch.no_grad():
209
+ try:
210
+ for mfcc, image, label in val_loader:
211
+ mfcc, image, label = (
212
+ mfcc.to(device),
213
+ image.to(device),
214
+ label.to(device),
215
+ )
216
+ output = model(mfcc, image)
217
+ loss = criterion(output, label.view(-1, 1).float())
218
+ val_loss += loss.item()
219
+ except Exception as e:
220
+ print(f"\033[91mERR!\033[0m: {e}")
221
+
222
+ avg_val_loss = val_loss / len(val_loader)
223
+
224
+ # Log validation loss
225
+ writer.add_scalar("Validation Loss", avg_val_loss, epoch)
226
+
227
+ print(
228
+ f"Epoch [{epoch+1}/{epochs}], Training Loss: {running_loss/len(train_loader):.4f}, "
229
+ f"Validation Loss: {avg_val_loss:.4f}"
230
+ )
231
+
232
+ # Save a model checkpoint
233
+ timestamp = time.strftime("%Y%m%d-%H%M%S")
234
+ model_path = f"models/model_{epoch+1}_{timestamp}.pt"
235
+ torch.save(model.state_dict(), model_path)
236
+
237
+ # Save the best model based on validation loss
238
+ if avg_val_loss < best_val_loss:
239
+ best_val_loss = avg_val_loss
240
+ best_model_path = model_path
241
+ print(f"\033[92mINFO\033[0m: New best model saved with validation loss: {best_val_loss:.4f}")
242
+
243
+ print(
244
+ f"\033[92mINFO\033[0m: Model checkpoint epoch [{epoch+1}/{epochs}] saved: {model_path}"
245
+ )
246
+
247
+ print(f"\033[92mINFO\033[0m: Training complete")
248
+
249
+ # Load the best model for testing
250
+ print(f"\033[92mINFO\033[0m: Loading best model from {best_model_path} for testing")
251
+ model.load_state_dict(torch.load(best_model_path))
252
+
253
+ # Evaluate on test set
254
+ test_loss, test_mae, predictions, labels = evaluate_model(model, test_loader, criterion)
255
+
256
+ # Calculate additional metrics
257
+ max_error = np.max(np.abs(predictions - labels))
258
+ min_error = np.min(np.abs(predictions - labels))
259
+
260
+ print("\n" + "="*50)
261
+ print("TEST RESULTS:")
262
+ print(f"Test Loss (MSE): {test_loss:.4f}")
263
+ print(f"Mean Absolute Error: {test_mae:.4f}")
264
+ print(f"Maximum Absolute Error: {max_error:.4f}")
265
+ print(f"Minimum Absolute Error: {min_error:.4f}")
266
+
267
+ # Add test results to TensorBoard
268
+ writer.add_scalar("Test/MSE", test_loss, 0)
269
+ writer.add_scalar("Test/MAE", test_mae, 0)
270
+ writer.add_scalar("Test/Max_Error", max_error, 0)
271
+ writer.add_scalar("Test/Min_Error", min_error, 0)
272
+
273
+ # Create a histogram of absolute errors
274
+ abs_errors = np.abs(predictions - labels)
275
+ writer.add_histogram("Test/Absolute_Errors", abs_errors, 0)
276
+
277
+ print("="*50)
278
+
279
+ writer.close()
280
+
281
+
282
+ if __name__ == "__main__":
283
+ train_model()
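For orientation, a minimal shape check of WatermelonModel's forward pass, assuming the MFCC layout produced by preprocess.py (a (13, 376) tensor, which the LSTM consumes as 13 time steps of 376 features once batched) and a normalized 1080x1080 RGB image. Note that importing train executes its module-level setup (version prints, device selection, creating ./models/):

import torch
from train import WatermelonModel

model = WatermelonModel().eval()        # downloads ImageNet ResNet-50 weights on first use
mfcc = torch.randn(1, 13, 376)          # (batch, time steps, features) as the LSTM expects
image = torch.randn(1, 3, 1080, 1080)   # batched output of process_image_data
with torch.no_grad():
    print(model(mfcc, image).shape)     # torch.Size([1, 1]): one sweetness score per sample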
train_2.py ADDED
@@ -0,0 +1,449 @@
1
+ import os
2
+ import time
3
+ import argparse
4
+ import torch
5
+ import torchaudio
6
+ import torchvision
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from torch.utils.tensorboard import SummaryWriter
9
+ import numpy as np
10
+ from efficient_model import MobileNetGRUModel, EfficientNetCNNModel, SqueezeNetTransformerModel
11
+
12
+ # Print library version information
13
+ print(f"\033[92mINFO\033[0m: PyTorch version: {torch.__version__}")
14
+ print(f"\033[92mINFO\033[0m: Torchaudio version: {torchaudio.__version__}")
15
+ print(f"\033[92mINFO\033[0m: Torchvision version: {torchvision.__version__}")
16
+
17
+ # Device selection
18
+ device = torch.device(
19
+ "cuda"
20
+ if torch.cuda.is_available()
21
+ else "mps" if torch.backends.mps.is_available() else "cpu"
22
+ )
23
+ print(f"\033[92mINFO\033[0m: Using device: {device}")
24
+
25
+ # Hyperparameters (using the best configuration from search)
26
+ batch_size = 4
27
+ epochs = 20
28
+ fc_hidden_size = 64
29
+ learning_rate = 0.0005
30
+ dropout_rate = 0.5
31
+
32
+ # Model save directory
33
+ os.makedirs("./models/", exist_ok=True)
34
+
35
+
36
+ class PreprocessedDataset(Dataset):
37
+ def __init__(self, data_dir):
38
+ self.data_dir = data_dir
39
+ self.samples = [
40
+ os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".pt")
41
+ ]
42
+
43
+ def __len__(self):
44
+ return len(self.samples)
45
+
46
+ def __getitem__(self, idx):
47
+ sample_path = self.samples[idx]
48
+ mfcc, image, label = torch.load(sample_path)
49
+ return mfcc.float(), image.float(), label
50
+
51
+
52
+ def calculate_mae(outputs, labels):
53
+ """Calculate Mean Absolute Error between outputs and labels"""
54
+ return torch.abs(outputs - labels).mean().item()
55
+
56
+
57
+ def evaluate_model(model, test_loader, criterion):
58
+ model.eval()
59
+ test_loss = 0.0
60
+ mae_sum = 0.0
61
+ all_predictions = []
62
+ all_labels = []
63
+
64
+ # For debugging
65
+ debug_samples = []
66
+
67
+ with torch.no_grad():
68
+ for mfcc, image, label in test_loader:
69
+ mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
70
+ output = model(mfcc, image)
71
+ label = label.view(-1, 1).float()
72
+
73
+ # Store debug samples (handling batch dimension properly)
74
+ if len(debug_samples) < 5:
75
+ # Extract individual samples from the batch
76
+ for i in range(min(len(output), 5 - len(debug_samples))):
77
+ debug_samples.append((output[i][0].item(), label[i][0].item()))
78
+
79
+ # Calculate MSE loss
80
+ loss = criterion(output, label)
81
+ test_loss += loss.item()
82
+
83
+ # Calculate MAE
84
+ mae = torch.abs(output - label).mean()
85
+ mae_sum += mae.item()
86
+
87
+ # Store predictions and labels for additional analysis
88
+ all_predictions.extend(output.cpu().numpy())
89
+ all_labels.extend(label.cpu().numpy())
90
+
91
+ avg_loss = test_loss / len(test_loader)
92
+ avg_mae = mae_sum / len(test_loader)
93
+
94
+ # Convert to numpy arrays for easier analysis
95
+ all_predictions = np.array(all_predictions).flatten()
96
+ all_labels = np.array(all_labels).flatten()
97
+
98
+ # Print debug samples
99
+ print("\nDEBUG SAMPLES (Prediction, Label):")
100
+ for i, (pred, label) in enumerate(debug_samples):
101
+ print(f"Sample {i+1}: Prediction = {pred:.4f}, Label = {label:.4f}, Difference = {abs(pred-label):.4f}")
102
+
103
+ return avg_loss, avg_mae, all_predictions, all_labels
104
+
105
+
106
+ def train_model(model_type):
107
+ try:
108
+ # Create model based on type
109
+ if model_type == "mobilenet_gru":
110
+ model = MobileNetGRUModel(
111
+ gru_hidden_size=32,
112
+ gru_layers=1,
113
+ fc_hidden_size=fc_hidden_size,
114
+ dropout_rate=dropout_rate
115
+ ).to(device)
116
+ model_name = "MobileNetGRU"
117
+ elif model_type == "efficientnet_cnn":
118
+ model = EfficientNetCNNModel(
119
+ fc_hidden_size=fc_hidden_size,
120
+ dropout_rate=dropout_rate
121
+ ).to(device)
122
+ model_name = "EfficientNetCNN"
123
+ elif model_type == "squeezenet_transformer":
124
+ model = SqueezeNetTransformerModel(
125
+ nhead=4,
126
+ dim_feedforward=128,
127
+ fc_hidden_size=fc_hidden_size,
128
+ dropout_rate=dropout_rate
129
+ ).to(device)
130
+ model_name = "SqueezeNetTransformer"
131
+ else:
132
+ raise ValueError(f"Unknown model type: {model_type}")
133
+
134
+ # Data loading
135
+ data_dir = "./processed/"
136
+ dataset = PreprocessedDataset(data_dir)
137
+ n_samples = len(dataset)
138
+
139
+ # Check label range
140
+ all_labels = []
141
+ for i in range(min(10, len(dataset))):
142
+ _, _, label = dataset[i]
143
+ all_labels.append(label)
144
+
145
+ print("\nLABEL RANGE CHECK:")
146
+ print(f"Sample labels: {all_labels}")
147
+ print(f"Min label: {min(all_labels)}, Max label: {max(all_labels)}")
148
+
149
+ train_size = int(0.7 * n_samples)
150
+ val_size = int(0.2 * n_samples)
151
+ test_size = n_samples - train_size - val_size
152
+
153
+ train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
154
+ dataset, [train_size, val_size, test_size]
155
+ )
156
+
157
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
158
+ val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
159
+ test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
160
+
161
+ # Loss function and optimizer
162
+ criterion = torch.nn.MSELoss()
163
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
164
+
165
+ # TensorBoard
166
+ writer = SummaryWriter(f"runs/{model_name}/")
167
+ global_step = 0
168
+
169
+ print(f"\033[92mINFO\033[0m: Training {model_name} model for {epochs} epochs")
170
+ print(f"\033[92mINFO\033[0m: Training samples: {len(train_dataset)}")
171
+ print(f"\033[92mINFO\033[0m: Validation samples: {len(val_dataset)}")
172
+ print(f"\033[92mINFO\033[0m: Test samples: {len(test_dataset)}")
173
+ print(f"\033[92mINFO\033[0m: Batch size: {batch_size}")
174
+ print(f"\033[92mINFO\033[0m: Learning rate: {learning_rate}")
175
+ print(f"\033[92mINFO\033[0m: Dropout rate: {dropout_rate}")
176
+
177
+ best_val_loss = float('inf')
178
+ best_model_path = None
179
+
180
+ # Calculate model size
181
+ model_size = sum(p.numel() for p in model.parameters()) / 1e6 # in millions
182
+ print(f"\033[92mINFO\033[0m: Model parameters: {model_size:.2f}M")
183
+
184
+ # Training loop
185
+ for epoch in range(epochs):
186
+ print(f"\033[92mINFO\033[0m: Training epoch ({epoch+1}/{epochs})")
187
+
188
+ model.train()
189
+ running_loss = 0.0
190
+ running_mae = 0.0
191
+ n_batches = 0
192
+
193
+ start_time = time.time()
194
+
195
+ try:
196
+ for mfcc, image, label in train_loader:
197
+ mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
198
+
199
+ optimizer.zero_grad()
200
+ output = model(mfcc, image)
201
+ label = label.view(-1, 1).float()
202
+ loss = criterion(output, label)
203
+ loss.backward()
204
+ optimizer.step()
205
+
206
+ running_loss += loss.item()
207
+ running_mae += calculate_mae(output, label)
208
+ n_batches += 1
209
+
210
+ writer.add_scalar("Training/Loss", loss.item(), global_step)
211
+ writer.add_scalar("Training/MAE", calculate_mae(output, label), global_step)
212
+ global_step += 1
213
+ except Exception as e:
214
+ print(f"\033[91mERR!\033[0m: {e}")
215
+
216
+ epoch_time = time.time() - start_time
217
+
218
+ # Validation phase
219
+ model.eval()
220
+ val_loss = 0.0
221
+ val_mae = 0.0
222
+ val_batches = 0
223
+
224
+ with torch.no_grad():
225
+ try:
226
+ for mfcc, image, label in val_loader:
227
+ mfcc, image, label = (
228
+ mfcc.to(device),
229
+ image.to(device),
230
+ label.to(device),
231
+ )
232
+ output = model(mfcc, image)
233
+ label = label.view(-1, 1).float()
234
+
235
+ # Calculate loss
236
+ loss = criterion(output, label)
237
+ val_loss += loss.item()
238
+
239
+ # Calculate MAE
240
+ val_mae += calculate_mae(output, label)
241
+ val_batches += 1
242
+ except Exception as e:
243
+ print(f"\033[91mERR!\033[0m: {e}")
244
+
245
+ avg_train_loss = running_loss / n_batches
246
+ avg_train_mae = running_mae / n_batches
247
+ avg_val_loss = val_loss / val_batches
248
+ avg_val_mae = val_mae / val_batches
249
+
250
+ # Record validation metrics
251
+ writer.add_scalar("Validation/Loss", avg_val_loss, epoch)
252
+ writer.add_scalar("Validation/MAE", avg_val_mae, epoch)
253
+
254
+ print(
255
+ f"Epoch [{epoch+1}/{epochs}], Time: {epoch_time:.2f}s, "
256
+ f"Train Loss: {avg_train_loss:.4f}, Train MAE: {avg_train_mae:.4f}, "
257
+ f"Val Loss: {avg_val_loss:.4f}, Val MAE: {avg_val_mae:.4f}"
258
+ )
259
+
260
+ # Save model checkpoint
261
+ timestamp = time.strftime("%Y%m%d-%H%M%S")
262
+ model_path = f"models/{model_name}_model_{epoch+1}_{timestamp}.pt"
263
+ torch.save(model.state_dict(), model_path)
264
+
265
+ # Save the best model based on validation loss
266
+ if avg_val_loss < best_val_loss:
267
+ best_val_loss = avg_val_loss
268
+ best_model_path = model_path
269
+ print(f"\033[92mINFO\033[0m: New best model saved with validation loss: {best_val_loss:.4f}")
270
+
271
+ print(
272
+ f"\033[92mINFO\033[0m: Model checkpoint epoch [{epoch+1}/{epochs}] saved: {model_path}"
273
+ )
274
+
275
+ print(f"\033[92mINFO\033[0m: Training complete")
276
+
277
+ # Load the best model for testing
278
+ print(f"\033[92mINFO\033[0m: Loading best model from {best_model_path} for testing")
279
+ model.load_state_dict(torch.load(best_model_path))
280
+
281
+ # Evaluate on test set
282
+ test_loss, test_mae, predictions, labels = evaluate_model(model, test_loader, criterion)
283
+
284
+ # Calculate additional metrics
285
+ max_error = np.max(np.abs(predictions - labels))
286
+ min_error = np.min(np.abs(predictions - labels))
287
+
288
+ print("\n" + "="*50)
289
+ print(f"TEST RESULTS FOR {model_name}:")
290
+ print(f"Test Loss (MSE): {test_loss:.4f}")
291
+ print(f"Mean Absolute Error: {test_mae:.4f}")
292
+ print(f"Maximum Absolute Error: {max_error:.4f}")
293
+ print(f"Minimum Absolute Error: {min_error:.4f}")
294
+
295
+ # Add test results to TensorBoard
296
+ writer.add_scalar("Test/MSE", test_loss, 0)
297
+ writer.add_scalar("Test/MAE", test_mae, 0)
298
+ writer.add_scalar("Test/Max_Error", max_error, 0)
299
+ writer.add_scalar("Test/Min_Error", min_error, 0)
300
+
301
+ # Create a histogram of absolute errors
302
+ abs_errors = np.abs(predictions - labels)
303
+ writer.add_histogram("Test/Absolute_Errors", abs_errors, 0)
304
+
305
+ print("="*50)
306
+
307
+ # Final summary
308
+ print("\nTRAINING SUMMARY:")
309
+ print(f"Model: {model_name}")
310
+ print(f"Model Size: {model_size:.2f}M parameters")
311
+ print(f"Best Validation Loss: {best_val_loss:.4f}")
312
+ print(f"Final Test Loss: {test_loss:.4f}")
313
+ print(f"Final Test MAE: {test_mae:.4f}")
314
+ print(f"Best model saved at: {best_model_path}")
315
+
316
+ writer.close()
317
+
318
+ # Return metrics for comparison
319
+ return {
320
+ "model_name": model_name,
321
+ "model_size": model_size,
322
+ "val_loss": best_val_loss,
323
+ "test_loss": test_loss,
324
+ "test_mae": test_mae,
325
+ "model_path": best_model_path
326
+ }
327
+
328
+ except Exception as e:
329
+ print(f"\033[91mERR!\033[0m: Error training {model_type}: {e}")
330
+ # Return a placeholder result
331
+ return {
332
+ "model_name": model_type,
333
+ "model_size": 0,
334
+ "val_loss": float('inf'),
335
+ "test_loss": float('inf'),
336
+ "test_mae": float('inf'),
337
+ "model_path": None,
338
+ "error": str(e)
339
+ }
340
+
341
+
342
+ def test_cpu_inference(model_path, model_type):
343
+ """Test CPU inference speed for the given model"""
344
+ # Create model based on type
345
+ if model_type == "mobilenet_gru":
346
+ model = MobileNetGRUModel(
347
+ gru_hidden_size=32,
348
+ gru_layers=1,
349
+ fc_hidden_size=fc_hidden_size,
350
+ dropout_rate=dropout_rate
351
+ )
352
+ model_name = "MobileNetGRU"
353
+ elif model_type == "efficientnet_cnn":
354
+ model = EfficientNetCNNModel(
355
+ fc_hidden_size=fc_hidden_size,
356
+ dropout_rate=dropout_rate
357
+ )
358
+ model_name = "EfficientNetCNN"
359
+ elif model_type == "squeezenet_transformer":
360
+ model = SqueezeNetTransformerModel(
361
+ nhead=4,
362
+ dim_feedforward=128,
363
+ fc_hidden_size=fc_hidden_size,
364
+ dropout_rate=dropout_rate
365
+ )
366
+ model_name = "SqueezeNetTransformer"
367
+ else:
368
+ raise ValueError(f"Unknown model type: {model_type}")
369
+
370
+ # Load model weights
371
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
372
+ model.eval()
373
+
374
+ # Create dummy input
375
+ dummy_mfcc = torch.randn(1, 10, 376) # Batch size 1, 10 time steps, 376 features
376
+ dummy_image = torch.randn(1, 3, 224, 224) # Batch size 1, 3 channels, 224x224 image
377
+
378
+ # Warm-up
379
+ for _ in range(10):
380
+ _ = model(dummy_mfcc, dummy_image)
381
+
382
+ # Measure inference time
383
+ num_runs = 100
384
+ start_time = time.time()
385
+ for _ in range(num_runs):
386
+ _ = model(dummy_mfcc, dummy_image)
387
+ end_time = time.time()
388
+
389
+ avg_time = (end_time - start_time) / num_runs
390
+
391
+ print(f"\n{model_name} CPU Inference Time:")
392
+ print(f"Average over {num_runs} runs: {avg_time*1000:.2f} ms")
393
+
394
+ return avg_time
395
+
396
+
397
+ if __name__ == "__main__":
398
+ parser = argparse.ArgumentParser(description="Train and evaluate efficient models")
399
+ parser.add_argument(
400
+ "--model",
401
+ type=str,
402
+ choices=["mobilenet_gru", "efficientnet_cnn", "squeezenet_transformer", "all"],
403
+ default="all",
404
+ help="Model architecture to train"
405
+ )
406
+ args = parser.parse_args()
407
+
408
+ results = []
409
+
410
+ if args.model == "all":
411
+ # Train all models
412
+ for model_type in ["mobilenet_gru", "efficientnet_cnn", "squeezenet_transformer"]:
413
+ print(f"\n\n{'='*50}")
414
+ print(f"TRAINING {model_type.upper()}")
415
+ print(f"{'='*50}\n")
416
+ result = train_model(model_type)
417
+ results.append(result)
418
+
419
+ # Test CPU inference
420
+ inference_time = test_cpu_inference(result["model_path"], model_type)
421
+ result["inference_time"] = inference_time
422
+ else:
423
+ # Train specific model
424
+ result = train_model(args.model)
425
+ results.append(result)
426
+
427
+ # Test CPU inference
428
+ inference_time = test_cpu_inference(result["model_path"], args.model)
429
+ result["inference_time"] = inference_time
430
+
431
+ # Compare results
432
+ print("\n\n" + "="*80)
433
+ print("MODEL COMPARISON")
434
+ print("="*80)
435
+ print(f"{'Model':<25} {'Size (M)':<10} {'Val Loss':<10} {'Test Loss':<10} {'Test MAE':<10} {'CPU Time (ms)':<15}")
436
+ print("-"*80)
437
+
438
+ for result in results:
439
+ print(f"{result['model_name']:<25} {result['model_size']:<10.2f} {result['val_loss']:<10.4f} "
440
+ f"{result['test_loss']:<10.4f} {result['test_mae']:<10.4f} {result['inference_time']*1000:<15.2f}")
441
+
442
+ print("="*80)
443
+
444
+ # Find best model
445
+ best_model = min(results, key=lambda x: x["test_mae"])
446
+ print(f"\nBEST MODEL: {best_model['model_name']}")
447
+ print(f"Test MAE: {best_model['test_mae']:.4f}")
448
+ print(f"CPU Inference Time: {best_model['inference_time']*1000:.2f} ms")
449
+ print(f"Model Path: {best_model['model_path']}")