Upload folder using huggingface_hub
- .gitattributes +1 -0
- README.md +3 -9
- __pycache__/infer.cpython-39.pyc +0 -0
- __pycache__/preprocess.cpython-39.pyc +0 -0
- __pycache__/train.cpython-39.pyc +0 -0
- __pycache__/train_2.cpython-39.pyc +0 -0
- app.py +284 -0
- infer.py +52 -0
- preprocess.py +45 -0
- preprocess_file.py +149 -0
- temp/temp_audio.wav +0 -0
- temp/temp_image.jpg +3 -0
- train.py +283 -0
- train_2.py +449 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+temp/temp_image.jpg filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,12 +1,6 @@
 ---
-title:
-emoji: 👁
-colorFrom: pink
-colorTo: red
-sdk: gradio
-sdk_version: 5.23.3
+title: watermelon
 app_file: app.py
-
+sdk: gradio
+sdk_version: 4.44.1
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__pycache__/infer.cpython-39.pyc
ADDED
Binary file (1.48 kB)
__pycache__/preprocess.cpython-39.pyc
ADDED
Binary file (1.27 kB)
__pycache__/train.cpython-39.pyc
ADDED
Binary file (7.13 kB)
__pycache__/train_2.cpython-39.pyc
ADDED
Binary file (10.7 kB)
app.py
ADDED
@@ -0,0 +1,284 @@
import torch, torchaudio, torchvision
import os
import gradio as gr
import numpy as np

from preprocess import process_audio_data, process_image_data
from train import WatermelonModel
from infer import infer

def load_model(model_path):
    global device
    device = torch.device(
        "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    )
    print(f"\033[92mINFO\033[0m: Using device: {device}")

    # Check if the file exists
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at {model_path}")

    # Check if the file is empty or very small
    file_size = os.path.getsize(model_path)
    if file_size < 1000:  # Less than 1KB is suspiciously small for a model
        print(f"\033[93mWARNING\033[0m: Model file size is only {file_size} bytes, which is suspiciously small")

    try:
        model = WatermelonModel().to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        print(f"\033[92mINFO\033[0m: Loaded model from {model_path}")
        return model
    except RuntimeError as e:
        if "failed finding central directory" in str(e):
            print(f"\033[91mERROR\033[0m: The model file at {model_path} appears to be corrupted.")
            print("This can happen if:")
            print("  1. The model saving process was interrupted")
            print("  2. The file was not properly downloaded")
            print("  3. The path points to a file that is not a valid PyTorch model")
            print(f"File size: {file_size} bytes")
        raise

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Watermelon sweetness predictor")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the trained model")
    args = parser.parse_args()

    model = load_model(args.model_path)

    def predict(audio, image):
        try:
            # Debug audio input
            print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
            print(f"\033[92mDEBUG\033[0m: Audio input value: {audio}")

            # Handle different formats of audio input from Gradio
            if audio is None:
                return "Error: No audio provided. Please upload or record audio."

            if isinstance(audio, tuple) and len(audio) >= 2:
                sr, audio_data = audio[0], audio[-1]
                print(f"\033[92mDEBUG\033[0m: Audio format: sr={sr}, audio_data shape={audio_data.shape if hasattr(audio_data, 'shape') else 'no shape'}")
            elif isinstance(audio, tuple) and len(audio) == 1:
                # Handle single element tuple
                audio_data = audio[0]
                sr = 44100  # Assume default sample rate
                print(f"\033[92mDEBUG\033[0m: Single element audio tuple, using default sr={sr}")
            elif isinstance(audio, np.ndarray):
                # Handle direct numpy array
                audio_data = audio
                sr = 44100  # Assume default sample rate
                print(f"\033[92mDEBUG\033[0m: Audio is numpy array, using default sr={sr}")
            else:
                return f"Error: Unexpected audio format: {type(audio)}"

            # Ensure audio_data is correctly shaped
            if isinstance(audio_data, np.ndarray):
                # Make sure we have a 2D array
                if len(audio_data.shape) == 1:
                    audio_data = np.expand_dims(audio_data, axis=0)
                    print(f"\033[92mDEBUG\033[0m: Reshaped 1D audio to 2D: {audio_data.shape}")

                # If channels are the second dimension, transpose
                if len(audio_data.shape) == 2 and audio_data.shape[0] > audio_data.shape[1]:
                    audio_data = np.transpose(audio_data)
                    print(f"\033[92mDEBUG\033[0m: Transposed audio shape to: {audio_data.shape}")

            # Convert to tensor
            audio_tensor = torch.tensor(audio_data).float()
            print(f"\033[92mDEBUG\033[0m: Audio tensor shape: {audio_tensor.shape}")

            # Process audio data and handle None case
            mfcc = process_audio_data(audio_tensor, sr)
            if mfcc is None:
                return "Error: Failed to process audio data. Make sure your audio contains a clear tapping sound."

            mfcc = mfcc.to(device)
            print(f"\033[92mDEBUG\033[0m: MFCC shape: {mfcc.shape}")

            # Debug image input
            print(f"\033[92mDEBUG\033[0m: Image input type: {type(image)}")
            print(f"\033[92mDEBUG\033[0m: Image shape: {image.shape if hasattr(image, 'shape') else 'No shape'}")

            # Process image data and handle None case
            if image is None:
                return "Error: No image provided. Please upload an image."

            # Handle different image formats
            if isinstance(image, np.ndarray):
                # Check if image is properly formatted (H, W, C) with 3 channels
                if len(image.shape) == 3 and image.shape[2] == 3:
                    # Convert to tensor with shape (C, H, W) as expected by PyTorch
                    img = torch.tensor(image).float().permute(2, 0, 1)
                    print(f"\033[92mDEBUG\033[0m: Converted image to tensor with shape: {img.shape}")
                elif len(image.shape) == 2:
                    # Grayscale image, expand to 3 channels
                    img = torch.tensor(image).float().unsqueeze(0).repeat(3, 1, 1)
                    print(f"\033[92mDEBUG\033[0m: Converted grayscale image to RGB tensor with shape: {img.shape}")
                else:
                    return f"Error: Unexpected image shape: {image.shape}. Expected RGB or grayscale image."
            else:
                return f"Error: Unexpected image format: {type(image)}. Expected numpy array."

            # Scale pixel values to [0, 1] if needed
            if img.max() > 1.0:
                img = img / 255.0
                print(f"\033[92mDEBUG\033[0m: Scaled image pixel values to range [0, 1]")

            # Get image dimensions and check if they're reasonable
            print(f"\033[92mDEBUG\033[0m: Final image tensor shape before processing: {img.shape}")

            # Process image
            try:
                img_processed = process_image_data(img)
                if img_processed is None:
                    return "Error: Failed to process image data. Make sure your image clearly shows a watermelon."

                img_processed = img_processed.to(device)
                print(f"\033[92mDEBUG\033[0m: Processed image shape: {img_processed.shape}")
            except Exception as e:
                print(f"\033[91mERROR\033[0m: Image processing error: {str(e)}")
                return f"Error in image processing: {str(e)}"

            # Run inference
            try:
                # Based on the error, it seems infer() expects file paths, not tensors
                # Let's create temporary files for the processed data
                temp_dir = os.path.join(os.getcwd(), "temp")
                os.makedirs(temp_dir, exist_ok=True)

                # Save the audio to a temporary file if infer expects a file path
                temp_audio_path = os.path.join(temp_dir, "temp_audio.wav")
                if not isinstance(audio, str) and isinstance(audio, tuple) and len(audio) >= 2:
                    # If we have the original audio data and sample rate
                    audio_array = audio[-1]
                    sr = audio[0]

                    # Check if the audio array is valid
                    if audio_array.size == 0:
                        return "Error: Audio data is empty. Please record a longer audio clip."

                    # Get the duration of the audio
                    duration = audio_array.shape[-1] / sr
                    print(f"\033[92mDEBUG\033[0m: Audio duration: {duration:.2f} seconds")

                    # Check if we have at least 1 second of audio - but don't reject, just pad if needed
                    min_duration = 1.0  # minimum 1 second of audio
                    if duration < min_duration:
                        print(f"\033[93mWARNING\033[0m: Audio is shorter than {min_duration} seconds. Padding will be applied.")
                        # Calculate samples needed to reach minimum duration
                        samples_needed = int(min_duration * sr) - audio_array.shape[-1]
                        # Pad with zeros
                        padding = np.zeros((audio_array.shape[0], samples_needed), dtype=audio_array.dtype)
                        audio_array = np.concatenate([audio_array, padding], axis=1)
                        print(f"\033[92mDEBUG\033[0m: Padded audio to shape: {audio_array.shape}")

                    # Make sure audio has 2 dimensions
                    if len(audio_array.shape) == 1:
                        audio_array = np.expand_dims(audio_array, axis=0)

                    print(f"\033[92mDEBUG\033[0m: Audio array shape before saving: {audio_array.shape}, sr: {sr}")

                    # Make sure it's in the right format for torchaudio.save
                    audio_tensor = torch.tensor(audio_array).float()
                    if audio_tensor.dim() == 1:
                        audio_tensor = audio_tensor.unsqueeze(0)

                    torchaudio.save(temp_audio_path, audio_tensor, sr)
                    print(f"\033[92mDEBUG\033[0m: Saved temporary audio file to {temp_audio_path}")

                    # Let's also process the audio here to verify it works
                    test_mfcc = process_audio_data(audio_tensor, sr)
                    if test_mfcc is None:
                        return "Error: Unable to process the audio. Please try recording a different audio sample."
                    else:
                        print(f"\033[92mDEBUG\033[0m: Audio pre-check passed. MFCC shape: {test_mfcc.shape}")

                    audio_path = temp_audio_path
                else:
                    # If we don't have a valid path, return an error
                    return "Error: Cannot process audio for inference. Invalid audio format."

                # Save the image to a temporary file if infer expects a file path
                temp_image_path = os.path.join(temp_dir, "temp_image.jpg")
                if isinstance(image, np.ndarray):
                    import cv2
                    cv2.imwrite(temp_image_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
                    print(f"\033[92mDEBUG\033[0m: Saved temporary image file to {temp_image_path}")
                    image_path = temp_image_path
                else:
                    # If we don't have a valid image, return an error
                    return "Error: Cannot process image for inference. Invalid image format."

                # Create a modified version of infer that handles None returns
                def safe_infer(audio_path, image_path, model, device):
                    try:
                        return infer(audio_path, image_path, model, device)
                    except Exception as e:
                        print(f"\033[91mERROR\033[0m: Error in infer function: {str(e)}")
                        # Try a more direct approach
                        try:
                            # Load audio and process
                            audio, sr = torchaudio.load(audio_path)
                            mfcc = process_audio_data(audio, sr)
                            if mfcc is None:
                                raise ValueError("Audio processing failed - MFCC is None")
                            mfcc = mfcc.to(device)

                            # Load image and process
                            image = cv2.imread(image_path)
                            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                            image_tensor = torch.tensor(image).float().permute(2, 0, 1) / 255.0
                            img_processed = process_image_data(image_tensor)
                            if img_processed is None:
                                raise ValueError("Image processing failed - processed image is None")
                            img_processed = img_processed.to(device)

                            # Run model inference
                            with torch.no_grad():
                                prediction = model(mfcc, img_processed)
                            return prediction
                        except Exception as e2:
                            print(f"\033[91mERROR\033[0m: Fallback inference also failed: {str(e2)}")
                            raise

                # Call our safer version
                print(f"\033[92mDEBUG\033[0m: Calling safe_infer with audio_path={audio_path}, image_path={image_path}")
                sweetness = safe_infer(audio_path, image_path, model, device)
                if sweetness is None:
                    return "Error: The model was unable to make a prediction. Please try with different inputs."

                print(f"\033[92mDEBUG\033[0m: Inference result: {sweetness.item()}")
                return f"Predicted Sweetness: {sweetness.item():.2f}/10"
            except Exception as e:
                import traceback
                print(f"\033[91mERROR\033[0m: Inference failed: {str(e)}")
                print(f"\033[91mTraceback\033[0m: {traceback.format_exc()}")
                return f"Error during inference: {str(e)}"

        except Exception as e:
            import traceback
            print(f"\033[91mERROR\033[0m: Prediction failed: {str(e)}")
            print(f"\033[91mTraceback\033[0m: {traceback.format_exc()}")
            return f"Error processing input: {str(e)}"

    audio_input = gr.Audio(label="Upload or Record Audio")
    image_input = gr.Image(label="Upload or Capture Image")
    output = gr.Textbox(label="Predicted Sweetness")

    interface = gr.Interface(
        fn=predict,
        inputs=[audio_input, image_input],
        outputs=output,
        title="Watermelon Sweetness Predictor",
        description="Upload an audio file and an image to predict the sweetness of a watermelon."
    )

    try:
        interface.launch(share=True)  # Enable sharing to avoid localhost access issues
    except Exception as e:
        print(f"\033[91mERROR\033[0m: Failed to launch interface: {e}")
        print("\033[93mTIP\033[0m: If you're running in a remote environment or container, try setting additional parameters:")
        print("  interface.launch(server_name='0.0.0.0', share=True)")
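The branching in predict mirrors what Gradio hands to the callback with the components' default type settings: gr.Audio yields a (sample_rate, samples) tuple holding a NumPy array, and gr.Image yields an HxWx3 uint8 array. A small sketch of such inputs (synthetic values for illustration only, not part of the commit):

import numpy as np

sr = 44100
samples = (0.1 * np.random.randn(3 * sr)).astype(np.float32)   # 3 s of fake "tap" audio
audio_input = (sr, samples)                                     # shape of input gr.Audio passes to predict()
image_input = np.random.randint(0, 255, (1080, 1440, 3), dtype=np.uint8)  # shape gr.Image passes
print(type(audio_input), image_input.shape)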
infer.py
ADDED
@@ -0,0 +1,52 @@
import torch, torchaudio
import argparse
from preprocess import process_audio_data, process_image_data
from train_2 import WatermelonModel


def infer(audio, image, model, device):
    # Load and preprocess the input data
    audio, sr = torchaudio.load(audio)
    mfcc = process_audio_data(audio, sr)
    img = process_image_data(image)
    if mfcc is None or img is None:
        return None
    mfcc = mfcc.to(device)
    img = img.to(device)

    # Run inference
    with torch.no_grad():
        predicted_sweetness = model(mfcc, img).item()

    return predicted_sweetness


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run Watermelon Sweetness Prediction")
    parser.add_argument(
        "--model_path", type=str, required=True, help="Path to the saved model file"
    )
    parser.add_argument(
        "--audio_path", type=str, required=True, help="Path to audio file"
    )
    parser.add_argument(
        "--image_path", type=str, required=True, help="Path to image file"
    )
    args = parser.parse_args()

    # Initialize the model and device
    print(f"\033[92mINFO\033[0m: PyTorch version: {torch.__version__}")
    device = torch.device(
        "cuda"
        if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available() else "cpu"
    )
    print(f"\033[92mINFO\033[0m: Using device: {device}")
    model = WatermelonModel().to(device)
    model.load_state_dict(torch.load(args.model_path, map_location=device))

    # Paths to the audio and image files
    audio_path = args.audio_path
    image_path = args.image_path

    # Run inference
    sweetness = infer(audio_path, image_path, model, device)
    print(f"Predicted sweetness: {sweetness}")
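For reference, a typical invocation of this script would look like python infer.py --model_path models/model_20_20250101-000000.pt --audio_path samples/tap.wav --image_path samples/melon.jpg, where the checkpoint and sample paths are illustrative only; note that the script assumes WatermelonModel can be imported from train_2.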
preprocess.py
ADDED
@@ -0,0 +1,45 @@
import torch
import torchaudio
import torchvision

resample_rate = 16000

def process_audio_data(waveform, sample_rate):
    try:
        waveform = waveform[0]  # use the left channel
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=resample_rate)(waveform)

        if waveform.size(0) < 3 * resample_rate:
            waveform = torch.nn.functional.pad(waveform, (0, 3 * resample_rate - waveform.size(0)))
        else:
            waveform = waveform[: 3 * resample_rate]

        mfcc = torchaudio.transforms.MFCC(
            sample_rate=resample_rate,
            n_mfcc=13,
            melkwargs={
                "n_fft": 256,
                "win_length": 256,
                "hop_length": 128,
                "n_mels": 40,
            }
        )(waveform)

        return mfcc
    except Exception as e:
        print(f"ERR!: Error in audio processing: {e}")
        return None

def process_image_data(image):
    try:
        image = torchvision.transforms.Resize((1080, 1080))(image)
        image = image / 255.0
        image = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )(image)

        return image
    except Exception as e:
        print(f"ERR!: Error in image processing: {e}")
        return None
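A quick way to sanity-check these two helpers is to run them on synthetic tensors. The sketch below is not part of the commit and assumes only torch plus this preprocess module; the output shapes follow from the 16 kHz resample, the 3-second pad/crop, and the hop length of 128.

import torch
from preprocess import process_audio_data, process_image_data

fake_wave = torch.randn(2, 2 * 44100)        # 2-channel, 2 s of noise at 44.1 kHz
mfcc = process_audio_data(fake_wave, 44100)  # left channel -> 16 kHz -> 3 s -> MFCC
print(mfcc.shape)                            # torch.Size([13, 376])

fake_image = torch.randint(0, 256, (3, 720, 960)).float()
processed = process_image_data(fake_image)   # resize to 1080x1080, scale, normalize
print(processed.shape)                       # torch.Size([3, 1080, 1080])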
preprocess_file.py
ADDED
@@ -0,0 +1,149 @@
import os
import glob
import torch
import torchaudio
import torchvision
from torch.utils.data import Dataset
from concurrent.futures import ThreadPoolExecutor
from preprocess import process_audio_data, process_image_data, resample_rate

class PreprocessedDataset(Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.samples = [
            os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".pt")
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample_path = self.samples[idx]
        mfcc, image, label = torch.load(sample_path)

        # Process data
        mfcc = process_audio_data(mfcc, resample_rate)
        image = process_image_data(image)

        return mfcc, image, label

def load_audio_file(audio_path):
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    try:
        # Try the default torchaudio loader first
        waveform, sample_rate = torchaudio.load(audio_path)
    except Exception as e:
        print(f"Warning: Could not load {audio_path} with torchaudio: {e}")

        # Fall back to librosa (you'll need to install it: pip install librosa)
        try:
            import librosa
            import numpy as np

            waveform_np, sample_rate = librosa.load(audio_path, sr=None)
            # Convert to torch tensor with shape [1, length] to match torchaudio format
            waveform = torch.from_numpy(waveform_np[np.newaxis, :]).float()
            print(f"Successfully loaded with librosa: {audio_path}")
        except Exception as final_e:
            raise RuntimeError(f"Failed to load audio file {audio_path} with all available methods: {final_e}")

    return waveform, sample_rate

def load_image_file(image_path):
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")

    image = torchvision.io.read_image(image_path)
    return image

def process_sample(sample_path, save_dir):
    # Recursively search for audio and image files
    audio_files = []
    image_files = []

    # Walk through all subdirectories
    for root, _, files in os.walk(sample_path):
        for file in files:
            if file.lower().endswith(('.wav', '.mp3', '.flac')):
                audio_files.append(os.path.join(root, file))
            elif file.lower().endswith(('.jpg', '.jpeg', '.png')):
                image_files.append(os.path.join(root, file))

    if not audio_files:
        print(f"Warning: No audio file found in {sample_path}. Skipping this sample.")
        return

    if not image_files:
        print(f"Warning: No image file found in {sample_path}. Skipping this sample.")
        return

    # Use the first found audio and image files
    audio_path = audio_files[0]
    image_path = image_files[0]

    print(f"Processing audio: {audio_path}")
    print(f"Processing image: {image_path}")

    waveform, sample_rate = load_audio_file(audio_path)
    image = load_image_file(image_path)

    # Process data
    mfcc = process_audio_data(waveform, sample_rate)
    processed_image = process_image_data(image)

    # Save processed data
    save_path = os.path.join(save_dir, f"{os.path.basename(sample_path)}.pt")
    torch.save((mfcc, processed_image, float(os.path.basename(sample_path))), save_path)
    print(f"Processed and saved: {save_path}")

def process_and_save(data_dir, save_dir):
    os.makedirs(save_dir, exist_ok=True)
    sample_paths = [os.path.join(data_dir, d) for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))]

    if not sample_paths:
        print(f"Warning: No sample directories found in {data_dir}")
        return

    print(f"Found {len(sample_paths)} sample directories to process")

    successful = 0
    failed = 0

    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(process_sample, path, save_dir) for path in sample_paths]
        for future in futures:
            try:
                future.result()  # Wait for all threads to complete
                successful += 1
            except Exception as e:
                failed += 1
                print(f"Error processing a sample: {e}")

    print(f"Processing complete. Successfully processed: {successful}, Failed: {failed}")

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Preprocess the dataset")
    parser.add_argument(
        "--data_dir",
        type=str,
        default="cleaned",
        help="Path to the cleaned dataset directory",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="processed",
        help="Path to the processed dataset directory",
    )
    args = parser.parse_args()

    print(f"Processing dataset from: {args.data_dir}")
    print(f"Saving processed data to: {args.save_dir}")

    process_and_save(args.data_dir, args.save_dir)

    print("Preprocessing complete")
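As implied by process_sample and the float(os.path.basename(sample_path)) label extraction above, each sample is expected to sit in its own directory named after its sweetness label, containing at least one recording and one photo. A hypothetical layout (directory and file names are illustrative only):

cleaned/
  9.2/
    tap.wav
    melon.jpg
  10.4/
    tap.wav
    melon.jpg

Running python preprocess_file.py --data_dir cleaned --save_dir processed would then write one tensor file per sample directory (e.g. processed/9.2.pt).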
temp/temp_audio.wav
ADDED
Binary file (58 Bytes)
temp/temp_image.jpg
ADDED
Binary file (stored with Git LFS)
train.py
ADDED
@@ -0,0 +1,283 @@
import os
import time

import torch, torchaudio, torchvision
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np

# Print library version information
print(f"\033[92mINFO\033[0m: PyTorch version: {torch.__version__}")
print(f"\033[92mINFO\033[0m: Torchaudio version: {torchaudio.__version__}")
print(f"\033[92mINFO\033[0m: Torchvision version: {torchvision.__version__}")

# Device selection
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"\033[92mINFO\033[0m: Using device: {device}")

# Hyperparameter settings
batch_size = 1
epochs = 20

# Model save directory
os.makedirs("./models/", exist_ok=True)


class PreprocessedDataset(Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.samples = [
            os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".pt")
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample_path = self.samples[idx]
        mfcc, image, label = torch.load(sample_path)
        return mfcc.float(), image.float(), label


class WatermelonModel(torch.nn.Module):
    def __init__(self):
        super(WatermelonModel, self).__init__()

        # LSTM for audio features
        self.lstm = torch.nn.LSTM(
            input_size=376, hidden_size=64, num_layers=2, batch_first=True
        )
        self.lstm_fc = torch.nn.Linear(
            64, 128
        )  # Convert LSTM output to 128-dim for merging

        # ResNet50 for image features
        self.resnet = torchvision.models.resnet50(pretrained=True)
        self.resnet.fc = torch.nn.Linear(
            self.resnet.fc.in_features, 128
        )  # Convert ResNet output to 128-dim for merging

        # Fully connected layers for final prediction
        self.fc1 = torch.nn.Linear(256, 64)
        self.fc2 = torch.nn.Linear(64, 1)
        self.relu = torch.nn.ReLU()

    def forward(self, mfcc, image):
        # LSTM branch
        lstm_output, _ = self.lstm(mfcc)
        lstm_output = lstm_output[:, -1, :]  # Use the output of the last time step
        lstm_output = self.lstm_fc(lstm_output)

        # ResNet branch
        resnet_output = self.resnet(image)

        # Concatenate LSTM and ResNet outputs
        merged = torch.cat((lstm_output, resnet_output), dim=1)

        # Fully connected layers
        output = self.relu(self.fc1(merged))
        output = self.fc2(output)

        return output


def evaluate_model(model, test_loader, criterion):
    model.eval()
    test_loss = 0.0
    mae_sum = 0.0
    all_predictions = []
    all_labels = []

    # For debugging
    debug_samples = []

    with torch.no_grad():
        for mfcc, image, label in test_loader:
            mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
            output = model(mfcc, image)
            label = label.view(-1, 1).float()

            # Store debug samples
            if len(debug_samples) < 5:
                debug_samples.append((output.item(), label.item()))

            # Calculate MSE loss
            loss = criterion(output, label)
            test_loss += loss.item()

            # Calculate MAE
            mae = torch.abs(output - label).mean()
            mae_sum += mae.item()

            # Store predictions and labels for additional analysis
            all_predictions.extend(output.cpu().numpy())
            all_labels.extend(label.cpu().numpy())

    avg_loss = test_loss / len(test_loader)
    avg_mae = mae_sum / len(test_loader)

    # Convert to numpy arrays for easier analysis
    all_predictions = np.array(all_predictions).flatten()
    all_labels = np.array(all_labels).flatten()

    # Print debug samples
    print("\nDEBUG SAMPLES (Prediction, Label):")
    for i, (pred, label) in enumerate(debug_samples):
        print(f"Sample {i+1}: Prediction = {pred:.4f}, Label = {label:.4f}, Difference = {abs(pred-label):.4f}")

    return avg_loss, avg_mae, all_predictions, all_labels


def train_model():
    # Dataset loading
    data_dir = "./processed/"
    dataset = PreprocessedDataset(data_dir)
    n_samples = len(dataset)

    # Check label range
    all_labels = []
    for i in range(min(10, len(dataset))):
        _, _, label = dataset[i]
        all_labels.append(label)

    print("\nLABEL RANGE CHECK:")
    print(f"Sample labels: {all_labels}")
    print(f"Min label: {min(all_labels)}, Max label: {max(all_labels)}")

    train_size = int(0.7 * n_samples)
    val_size = int(0.2 * n_samples)
    test_size = n_samples - train_size - val_size

    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size]
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = WatermelonModel().to(device)

    # Loss function and optimizer
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # TensorBoard
    writer = SummaryWriter("runs/")
    global_step = 0

    print(f"\033[92mINFO\033[0m: Training model for {epochs} epochs")
    print(f"\033[92mINFO\033[0m: Training samples: {len(train_dataset)}")
    print(f"\033[92mINFO\033[0m: Validation samples: {len(val_dataset)}")
    print(f"\033[92mINFO\033[0m: Test samples: {len(test_dataset)}")
    print(f"\033[92mINFO\033[0m: Batch size: {batch_size}")

    best_val_loss = float('inf')
    best_model_path = None

    # Training loop
    for epoch in range(epochs):
        print(f"\033[92mINFO\033[0m: Training epoch ({epoch+1}/{epochs})")

        model.train()
        running_loss = 0.0
        try:
            for mfcc, image, label in train_loader:
                mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)

                optimizer.zero_grad()
                output = model(mfcc, image)
                label = label.view(-1, 1).float()
                loss = criterion(output, label)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                writer.add_scalar("Training Loss", loss.item(), global_step)
                global_step += 1
        except Exception as e:
            print(f"\033[91mERR!\033[0m: {e}")

        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            try:
                for mfcc, image, label in val_loader:
                    mfcc, image, label = (
                        mfcc.to(device),
                        image.to(device),
                        label.to(device),
                    )
                    output = model(mfcc, image)
                    loss = criterion(output, label.view(-1, 1))
                    val_loss += loss.item()
            except Exception as e:
                print(f"\033[91mERR!\033[0m: {e}")

        avg_val_loss = val_loss / len(val_loader)

        # Record the validation loss
        writer.add_scalar("Validation Loss", avg_val_loss, epoch)

        print(
            f"Epoch [{epoch+1}/{epochs}], Training Loss: {running_loss/len(train_loader):.4f}, "
            f"Validation Loss: {avg_val_loss:.4f}"
        )

        # Save model checkpoint
        timestamp = time.strftime("%Y%m%d-%H%M%S")
        model_path = f"models/model_{epoch+1}_{timestamp}.pt"
        torch.save(model.state_dict(), model_path)

        # Save the best model based on validation loss
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model_path = model_path
            print(f"\033[92mINFO\033[0m: New best model saved with validation loss: {best_val_loss:.4f}")

        print(
            f"\033[92mINFO\033[0m: Model checkpoint epoch [{epoch+1}/{epochs}] saved: {model_path}"
        )

    print(f"\033[92mINFO\033[0m: Training complete")

    # Load the best model for testing
    print(f"\033[92mINFO\033[0m: Loading best model from {best_model_path} for testing")
    model.load_state_dict(torch.load(best_model_path))

    # Evaluate on test set
    test_loss, test_mae, predictions, labels = evaluate_model(model, test_loader, criterion)

    # Calculate additional metrics
    max_error = np.max(np.abs(predictions - labels))
    min_error = np.min(np.abs(predictions - labels))

    print("\n" + "="*50)
    print("TEST RESULTS:")
    print(f"Test Loss (MSE): {test_loss:.4f}")
    print(f"Mean Absolute Error: {test_mae:.4f}")
    print(f"Maximum Absolute Error: {max_error:.4f}")
    print(f"Minimum Absolute Error: {min_error:.4f}")

    # Add test results to TensorBoard
    writer.add_scalar("Test/MSE", test_loss, 0)
    writer.add_scalar("Test/MAE", test_mae, 0)
    writer.add_scalar("Test/Max_Error", max_error, 0)
    writer.add_scalar("Test/Min_Error", min_error, 0)

    # Create a histogram of absolute errors
    abs_errors = np.abs(predictions - labels)
    writer.add_histogram("Test/Absolute_Errors", abs_errors, 0)

    print("="*50)

    writer.close()


if __name__ == "__main__":
    train_model()
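As a rough shape check of the two-branch fusion model, the sketch below (not part of the commit) feeds random tensors shaped like one preprocessed sample; note that instantiating WatermelonModel downloads the pretrained ResNet-50 weights and that importing train runs its module-level setup.

import torch
from train import WatermelonModel

model = WatermelonModel().eval()
mfcc = torch.randn(1, 13, 376)          # (batch, MFCC coefficients, time frames)
image = torch.randn(1, 3, 1080, 1080)   # (batch, channels, height, width)
with torch.no_grad():
    print(model(mfcc, image).shape)     # torch.Size([1, 1]) -> one sweetness value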
train_2.py
ADDED
@@ -0,0 +1,449 @@
import os
import time
import argparse
import torch
import torchaudio
import torchvision
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from efficient_model import MobileNetGRUModel, EfficientNetCNNModel, SqueezeNetTransformerModel

# Print library version information
print(f"\033[92mINFO\033[0m: PyTorch version: {torch.__version__}")
print(f"\033[92mINFO\033[0m: Torchaudio version: {torchaudio.__version__}")
print(f"\033[92mINFO\033[0m: Torchvision version: {torchvision.__version__}")

# Device selection
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"\033[92mINFO\033[0m: Using device: {device}")

# Hyperparameters (using the best configuration from search)
batch_size = 4
epochs = 20
fc_hidden_size = 64
learning_rate = 0.0005
dropout_rate = 0.5

# Model save directory
os.makedirs("./models/", exist_ok=True)


class PreprocessedDataset(Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.samples = [
            os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".pt")
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample_path = self.samples[idx]
        mfcc, image, label = torch.load(sample_path)
        return mfcc.float(), image.float(), label


def calculate_mae(outputs, labels):
    """Calculate Mean Absolute Error between outputs and labels"""
    return torch.abs(outputs - labels).mean().item()


def evaluate_model(model, test_loader, criterion):
    model.eval()
    test_loss = 0.0
    mae_sum = 0.0
    all_predictions = []
    all_labels = []

    # For debugging
    debug_samples = []

    with torch.no_grad():
        for mfcc, image, label in test_loader:
            mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
            output = model(mfcc, image)
            label = label.view(-1, 1).float()

            # Store debug samples (handling batch dimension properly)
            if len(debug_samples) < 5:
                # Extract individual samples from the batch
                for i in range(min(len(output), 5 - len(debug_samples))):
                    debug_samples.append((output[i][0].item(), label[i][0].item()))

            # Calculate MSE loss
            loss = criterion(output, label)
            test_loss += loss.item()

            # Calculate MAE
            mae = torch.abs(output - label).mean()
            mae_sum += mae.item()

            # Store predictions and labels for additional analysis
            all_predictions.extend(output.cpu().numpy())
            all_labels.extend(label.cpu().numpy())

    avg_loss = test_loss / len(test_loader)
    avg_mae = mae_sum / len(test_loader)

    # Convert to numpy arrays for easier analysis
    all_predictions = np.array(all_predictions).flatten()
    all_labels = np.array(all_labels).flatten()

    # Print debug samples
    print("\nDEBUG SAMPLES (Prediction, Label):")
    for i, (pred, label) in enumerate(debug_samples):
        print(f"Sample {i+1}: Prediction = {pred:.4f}, Label = {label:.4f}, Difference = {abs(pred-label):.4f}")

    return avg_loss, avg_mae, all_predictions, all_labels


def train_model(model_type):
    try:
        # Create model based on type
        if model_type == "mobilenet_gru":
            model = MobileNetGRUModel(
                gru_hidden_size=32,
                gru_layers=1,
                fc_hidden_size=fc_hidden_size,
                dropout_rate=dropout_rate
            ).to(device)
            model_name = "MobileNetGRU"
        elif model_type == "efficientnet_cnn":
            model = EfficientNetCNNModel(
                fc_hidden_size=fc_hidden_size,
                dropout_rate=dropout_rate
            ).to(device)
            model_name = "EfficientNetCNN"
        elif model_type == "squeezenet_transformer":
            model = SqueezeNetTransformerModel(
                nhead=4,
                dim_feedforward=128,
                fc_hidden_size=fc_hidden_size,
                dropout_rate=dropout_rate
            ).to(device)
            model_name = "SqueezeNetTransformer"
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        # Data loading
        data_dir = "./processed/"
        dataset = PreprocessedDataset(data_dir)
        n_samples = len(dataset)

        # Check label range
        all_labels = []
        for i in range(min(10, len(dataset))):
            _, _, label = dataset[i]
            all_labels.append(label)

        print("\nLABEL RANGE CHECK:")
        print(f"Sample labels: {all_labels}")
        print(f"Min label: {min(all_labels)}, Max label: {max(all_labels)}")

        train_size = int(0.7 * n_samples)
        val_size = int(0.2 * n_samples)
        test_size = n_samples - train_size - val_size

        train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
            dataset, [train_size, val_size, test_size]
        )

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        # Loss function and optimizer
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

        # TensorBoard
        writer = SummaryWriter(f"runs/{model_name}/")
        global_step = 0

        print(f"\033[92mINFO\033[0m: Training {model_name} model for {epochs} epochs")
        print(f"\033[92mINFO\033[0m: Training samples: {len(train_dataset)}")
        print(f"\033[92mINFO\033[0m: Validation samples: {len(val_dataset)}")
        print(f"\033[92mINFO\033[0m: Test samples: {len(test_dataset)}")
        print(f"\033[92mINFO\033[0m: Batch size: {batch_size}")
        print(f"\033[92mINFO\033[0m: Learning rate: {learning_rate}")
        print(f"\033[92mINFO\033[0m: Dropout rate: {dropout_rate}")

        best_val_loss = float('inf')
        best_model_path = None

        # Calculate model size
        model_size = sum(p.numel() for p in model.parameters()) / 1e6  # in millions
        print(f"\033[92mINFO\033[0m: Model parameters: {model_size:.2f}M")

        # Training loop
        for epoch in range(epochs):
            print(f"\033[92mINFO\033[0m: Training epoch ({epoch+1}/{epochs})")

            model.train()
            running_loss = 0.0
            running_mae = 0.0
            n_batches = 0

            start_time = time.time()

            try:
                for mfcc, image, label in train_loader:
                    mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)

                    optimizer.zero_grad()
                    output = model(mfcc, image)
                    label = label.view(-1, 1).float()
                    loss = criterion(output, label)
                    loss.backward()
                    optimizer.step()

                    running_loss += loss.item()
                    running_mae += calculate_mae(output, label)
                    n_batches += 1

                    writer.add_scalar("Training/Loss", loss.item(), global_step)
                    writer.add_scalar("Training/MAE", calculate_mae(output, label), global_step)
                    global_step += 1
            except Exception as e:
                print(f"\033[91mERR!\033[0m: {e}")

            epoch_time = time.time() - start_time

            # Validation phase
            model.eval()
            val_loss = 0.0
            val_mae = 0.0
            val_batches = 0

            with torch.no_grad():
                try:
                    for mfcc, image, label in val_loader:
                        mfcc, image, label = (
                            mfcc.to(device),
                            image.to(device),
                            label.to(device),
                        )
                        output = model(mfcc, image)
                        label = label.view(-1, 1).float()

                        # Calculate loss
                        loss = criterion(output, label)
                        val_loss += loss.item()

                        # Calculate MAE
                        val_mae += calculate_mae(output, label)
                        val_batches += 1
                except Exception as e:
                    print(f"\033[91mERR!\033[0m: {e}")

            avg_train_loss = running_loss / n_batches
            avg_train_mae = running_mae / n_batches
            avg_val_loss = val_loss / val_batches
            avg_val_mae = val_mae / val_batches

            # Record validation metrics
            writer.add_scalar("Validation/Loss", avg_val_loss, epoch)
            writer.add_scalar("Validation/MAE", avg_val_mae, epoch)

            print(
                f"Epoch [{epoch+1}/{epochs}], Time: {epoch_time:.2f}s, "
                f"Train Loss: {avg_train_loss:.4f}, Train MAE: {avg_train_mae:.4f}, "
                f"Val Loss: {avg_val_loss:.4f}, Val MAE: {avg_val_mae:.4f}"
            )

            # Save model checkpoint
            timestamp = time.strftime("%Y%m%d-%H%M%S")
            model_path = f"models/{model_name}_model_{epoch+1}_{timestamp}.pt"
            torch.save(model.state_dict(), model_path)

            # Save the best model based on validation loss
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                best_model_path = model_path
                print(f"\033[92mINFO\033[0m: New best model saved with validation loss: {best_val_loss:.4f}")

            print(
                f"\033[92mINFO\033[0m: Model checkpoint epoch [{epoch+1}/{epochs}] saved: {model_path}"
            )

        print(f"\033[92mINFO\033[0m: Training complete")

        # Load the best model for testing
        print(f"\033[92mINFO\033[0m: Loading best model from {best_model_path} for testing")
        model.load_state_dict(torch.load(best_model_path))

        # Evaluate on test set
        test_loss, test_mae, predictions, labels = evaluate_model(model, test_loader, criterion)

        # Calculate additional metrics
        max_error = np.max(np.abs(predictions - labels))
        min_error = np.min(np.abs(predictions - labels))

        print("\n" + "="*50)
        print(f"TEST RESULTS FOR {model_name}:")
        print(f"Test Loss (MSE): {test_loss:.4f}")
        print(f"Mean Absolute Error: {test_mae:.4f}")
        print(f"Maximum Absolute Error: {max_error:.4f}")
        print(f"Minimum Absolute Error: {min_error:.4f}")

        # Add test results to TensorBoard
        writer.add_scalar("Test/MSE", test_loss, 0)
        writer.add_scalar("Test/MAE", test_mae, 0)
        writer.add_scalar("Test/Max_Error", max_error, 0)
        writer.add_scalar("Test/Min_Error", min_error, 0)

        # Create a histogram of absolute errors
        abs_errors = np.abs(predictions - labels)
        writer.add_histogram("Test/Absolute_Errors", abs_errors, 0)

        print("="*50)

        # Final summary
        print("\nTRAINING SUMMARY:")
        print(f"Model: {model_name}")
        print(f"Model Size: {model_size:.2f}M parameters")
        print(f"Best Validation Loss: {best_val_loss:.4f}")
        print(f"Final Test Loss: {test_loss:.4f}")
        print(f"Final Test MAE: {test_mae:.4f}")
        print(f"Best model saved at: {best_model_path}")

        writer.close()

        # Return metrics for comparison
        return {
            "model_name": model_name,
            "model_size": model_size,
            "val_loss": best_val_loss,
            "test_loss": test_loss,
            "test_mae": test_mae,
            "model_path": best_model_path
        }

    except Exception as e:
        print(f"\033[91mERR!\033[0m: Error training {model_type}: {e}")
        # Return a placeholder result
        return {
            "model_name": model_type,
            "model_size": 0,
            "val_loss": float('inf'),
            "test_loss": float('inf'),
            "test_mae": float('inf'),
            "model_path": None,
            "error": str(e)
        }


def test_cpu_inference(model_path, model_type):
    """Test CPU inference speed for the given model"""
    # Create model based on type
    if model_type == "mobilenet_gru":
        model = MobileNetGRUModel(
            gru_hidden_size=32,
            gru_layers=1,
            fc_hidden_size=fc_hidden_size,
            dropout_rate=dropout_rate
        )
        model_name = "MobileNetGRU"
    elif model_type == "efficientnet_cnn":
        model = EfficientNetCNNModel(
            fc_hidden_size=fc_hidden_size,
            dropout_rate=dropout_rate
        )
        model_name = "EfficientNetCNN"
    elif model_type == "squeezenet_transformer":
        model = SqueezeNetTransformerModel(
            nhead=4,
            dim_feedforward=128,
            fc_hidden_size=fc_hidden_size,
            dropout_rate=dropout_rate
        )
        model_name = "SqueezeNetTransformer"
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    # Load model weights
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    model.eval()

    # Create dummy input
    dummy_mfcc = torch.randn(1, 10, 376)  # Batch size 1, 10 time steps, 376 features
    dummy_image = torch.randn(1, 3, 224, 224)  # Batch size 1, 3 channels, 224x224 image

    # Warm-up
    for _ in range(10):
        _ = model(dummy_mfcc, dummy_image)

    # Measure inference time
    num_runs = 100
    start_time = time.time()
    for _ in range(num_runs):
        _ = model(dummy_mfcc, dummy_image)
    end_time = time.time()

    avg_time = (end_time - start_time) / num_runs

    print(f"\n{model_name} CPU Inference Time:")
    print(f"Average over {num_runs} runs: {avg_time*1000:.2f} ms")

    return avg_time


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train and evaluate efficient models")
    parser.add_argument(
        "--model",
        type=str,
        choices=["mobilenet_gru", "efficientnet_cnn", "squeezenet_transformer", "all"],
        default="all",
        help="Model architecture to train"
    )
    args = parser.parse_args()

    results = []

    if args.model == "all":
        # Train all models
        for model_type in ["mobilenet_gru", "efficientnet_cnn", "squeezenet_transformer"]:
            print(f"\n\n{'='*50}")
            print(f"TRAINING {model_type.upper()}")
            print(f"{'='*50}\n")
            result = train_model(model_type)
            results.append(result)

            # Test CPU inference
            inference_time = test_cpu_inference(result["model_path"], model_type)
            result["inference_time"] = inference_time
    else:
        # Train specific model
        result = train_model(args.model)
        results.append(result)

        # Test CPU inference
        inference_time = test_cpu_inference(result["model_path"], args.model)
        result["inference_time"] = inference_time

    # Compare results
    print("\n\n" + "="*80)
    print("MODEL COMPARISON")
    print("="*80)
    print(f"{'Model':<25} {'Size (M)':<10} {'Val Loss':<10} {'Test Loss':<10} {'Test MAE':<10} {'CPU Time (ms)':<15}")
    print("-"*80)

    for result in results:
        print(f"{result['model_name']:<25} {result['model_size']:<10.2f} {result['val_loss']:<10.4f} "
              f"{result['test_loss']:<10.4f} {result['test_mae']:<10.4f} {result['inference_time']*1000:<15.2f}")

    print("="*80)

    # Find best model
    best_model = min(results, key=lambda x: x["test_mae"])
    print(f"\nBEST MODEL: {best_model['model_name']}")
    print(f"Test MAE: {best_model['test_mae']:.4f}")
    print(f"CPU Inference Time: {best_model['inference_time']*1000:.2f} ms")
    print(f"Model Path: {best_model['model_path']}")