Xalphinions committed
Commit 5900417 · verified · 1 Parent(s): a14089e

Upload folder using huggingface_hub
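For context, a commit like this is typically produced with the upload_folder API of huggingface_hub. The sketch below is illustrative only: the repo id and folder path are assumed placeholders, not values taken from this commit.

    from huggingface_hub import HfApi

    api = HfApi()  # uses the token stored by `huggingface-cli login`

    # Upload the local project folder as a single commit; files matched by the
    # .gitattributes LFS rules (e.g. *.pt, temp/temp_audio.wav) go through Git LFS.
    api.upload_folder(
        folder_path=".",                    # local folder to upload (placeholder)
        repo_id="Xalphinions/watermelon2",  # assumed Space id, for illustration
        repo_type="space",
        commit_message="Upload folder using huggingface_hub",
    )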
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+temp/temp_audio.wav filter=lfs diff=lfs merge=lfs -text
+temp/temp_image.jpg filter=lfs diff=lfs merge=lfs -text
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
README.md CHANGED
@@ -1,12 +1,6 @@
 ---
-title: Watermelon2
-emoji: 🚀
-colorFrom: yellow
-colorTo: red
-sdk: gradio
-sdk_version: 5.23.3
+title: watermelon2
 app_file: app.py
-pinned: false
+sdk: gradio
+sdk_version: 5.9.1
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

__pycache__/infer_watermelon.cpython-310.pyc ADDED
Binary file (4.39 kB).
 
__pycache__/train_watermelon.cpython-310.pyc ADDED
Binary file (6.74 kB).
 
app.py ADDED
@@ -0,0 +1,318 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+ import numpy as np
5
+ import gradio as gr
6
+ import torchaudio
7
+ import torchvision
8
+
9
+ # Import Gradio Spaces GPU decorator
10
+ try:
11
+ from gradio import spaces
12
+ HAS_SPACES = True
13
+ print("\033[92mINFO\033[0m: Gradio Spaces detected, GPU acceleration will be enabled")
14
+ except ImportError:
15
+ HAS_SPACES = False
16
+ print("\033[93mWARN\033[0m: gradio.spaces not available, running without GPU optimization")
17
+
18
+ # Add parent directory to path to import preprocess functions
19
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
20
+
21
+ # Import functions from infer_watermelon.py and train_watermelon for the model
22
+ from train_watermelon import WatermelonModel
23
+
24
+ # Modified version of process_audio_data specifically for the app to handle various tensor shapes
25
+ def app_process_audio_data(waveform, sample_rate):
26
+ """Modified version of process_audio_data for the app that handles different tensor dimensions"""
27
+ try:
28
+ print(f"\033[92mDEBUG\033[0m: Processing audio - Initial shape: {waveform.shape}, Sample rate: {sample_rate}")
29
+
30
+ # Handle different tensor dimensions
31
+ if waveform.dim() == 3:
32
+ print(f"\033[92mDEBUG\033[0m: Found 3D tensor, converting to 2D")
33
+ # For 3D tensor, take the first item (batch dimension)
34
+ waveform = waveform[0]
35
+
36
+ if waveform.dim() == 2:
37
+ # Use the first channel for stereo audio
38
+ waveform = waveform[0]
39
+ print(f"\033[92mDEBUG\033[0m: Using first channel, new shape: {waveform.shape}")
40
+
41
+ # Resample to 16kHz if needed
42
+ resample_rate = 16000
43
+ if sample_rate != resample_rate:
44
+ print(f"\033[92mDEBUG\033[0m: Resampling from {sample_rate}Hz to {resample_rate}Hz")
45
+ waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=resample_rate)(waveform)
46
+
47
+ # Ensure 3 seconds of audio
48
+ if waveform.size(0) < 3 * resample_rate:
49
+ print(f"\033[92mDEBUG\033[0m: Padding audio from {waveform.size(0)} to {3 * resample_rate} samples")
50
+ waveform = torch.nn.functional.pad(waveform, (0, 3 * resample_rate - waveform.size(0)))
51
+ else:
52
+ print(f"\033[92mDEBUG\033[0m: Trimming audio from {waveform.size(0)} to {3 * resample_rate} samples")
53
+ waveform = waveform[: 3 * resample_rate]
54
+
55
+ # Apply MFCC transformation
56
+ print(f"\033[92mDEBUG\033[0m: Applying MFCC transformation")
57
+ mfcc_transform = torchaudio.transforms.MFCC(
58
+ sample_rate=resample_rate,
59
+ n_mfcc=13,
60
+ melkwargs={
61
+ "n_fft": 256,
62
+ "win_length": 256,
63
+ "hop_length": 128,
64
+ "n_mels": 40,
65
+ }
66
+ )
67
+
68
+ mfcc = mfcc_transform(waveform)
69
+ print(f"\033[92mDEBUG\033[0m: MFCC output shape: {mfcc.shape}")
70
+
71
+ return mfcc
72
+ except Exception as e:
73
+ import traceback
74
+ print(f"\033[91mERR!\033[0m: Error in audio processing: {e}")
75
+ print(traceback.format_exc())
76
+ return None
77
+
78
+ # Similarly for images, but let's import the original one
79
+ from preprocess import process_image_data
80
+
81
+ # Define prediction function
82
+ def predict_sweetness(audio, image, model_path):
83
+ """Predict sweetness of a watermelon from audio and image input"""
84
+ try:
85
+ # Now check CUDA availability inside the GPU-decorated function
86
+ if torch.cuda.is_available():
87
+ device = torch.device("cuda")
88
+ print(f"\033[92mINFO\033[0m: CUDA is available. Using device: {device}")
89
+ else:
90
+ device = torch.device("cpu")
91
+ print(f"\033[92mINFO\033[0m: CUDA is not available. Using device: {device}")
92
+
93
+ # Load model inside the function to ensure it's on the correct device
94
+ model = WatermelonModel().to(device)
95
+ model.load_state_dict(torch.load(model_path, map_location=device))
96
+ model.eval()
97
+ print(f"\033[92mINFO\033[0m: Loaded model from {model_path}")
98
+
99
+ # Debug information about input types
100
+ print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
101
+ print(f"\033[92mDEBUG\033[0m: Audio input shape/length: {len(audio)}")
102
+ print(f"\033[92mDEBUG\033[0m: Image input type: {type(image)}")
103
+ if isinstance(image, np.ndarray):
104
+ print(f"\033[92mDEBUG\033[0m: Image input shape: {image.shape}")
105
+
106
+ # Handle different audio input formats
107
+ if isinstance(audio, tuple) and len(audio) == 2:
108
+ # Standard Gradio format: (sample_rate, audio_data)
109
+ sample_rate, audio_data = audio
110
+ print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
111
+ print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
112
+ elif isinstance(audio, tuple) and len(audio) > 2:
113
+ # Sometimes Gradio returns (sample_rate, audio_data, other_info...)
114
+ sample_rate, audio_data = audio[0], audio[-1]
115
+ print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
116
+ print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
117
+ elif isinstance(audio, str):
118
+ # Direct path to audio file
119
+ audio_data, sample_rate = torchaudio.load(audio)
120
+ print(f"\033[92mDEBUG\033[0m: Loaded audio from path with shape: {audio_data.shape}")
121
+ else:
122
+ return f"Error: Unsupported audio format. Got {type(audio)}"
123
+
124
+ # Create a temporary file path for the audio and image
125
+ temp_dir = "temp"
126
+ os.makedirs(temp_dir, exist_ok=True)
127
+
128
+ temp_audio_path = os.path.join(temp_dir, "temp_audio.wav")
129
+ temp_image_path = os.path.join(temp_dir, "temp_image.jpg")
130
+
131
+ # Import necessary libraries
132
+ from PIL import Image
133
+
134
+ # Audio handling - direct processing from the data in memory
135
+ if isinstance(audio_data, np.ndarray):
136
+ # Convert numpy array to tensor
137
+ print(f"\033[92mDEBUG\033[0m: Converting numpy audio with shape {audio_data.shape} to tensor")
138
+ audio_tensor = torch.tensor(audio_data).float()
139
+
140
+ # Handle different audio dimensions
141
+ if audio_data.ndim == 1:
142
+ # Single channel audio
143
+ audio_tensor = audio_tensor.unsqueeze(0)
144
+ elif audio_data.ndim == 2:
145
+ # Ensure channels are first dimension
146
+ if audio_data.shape[0] > audio_data.shape[1]:
147
+ # More rows than columns, probably (samples, channels)
148
+ audio_tensor = torch.tensor(audio_data.T).float()
149
+ else:
150
+ # Already a tensor
151
+ audio_tensor = audio_data.float()
152
+
153
+ print(f"\033[92mDEBUG\033[0m: Audio tensor shape before processing: {audio_tensor.shape}")
154
+
155
+ # Skip saving/loading and process directly
156
+ mfcc = app_process_audio_data(audio_tensor, sample_rate)
157
+ print(f"\033[92mDEBUG\033[0m: MFCC tensor shape after processing: {mfcc.shape if mfcc is not None else None}")
158
+
159
+ # Image handling
160
+ if isinstance(image, np.ndarray):
161
+ print(f"\033[92mDEBUG\033[0m: Converting numpy image with shape {image.shape} to PIL")
162
+ pil_image = Image.fromarray(image)
163
+ pil_image.save(temp_image_path)
164
+ print(f"\033[92mDEBUG\033[0m: Saved image to {temp_image_path}")
165
+ elif isinstance(image, str):
166
+ # If image is already a path
167
+ temp_image_path = image
168
+ print(f"\033[92mDEBUG\033[0m: Using provided image path: {temp_image_path}")
169
+ else:
170
+ return f"Error: Unsupported image format. Got {type(image)}"
171
+
172
+ # Process image
173
+ print(f"\033[92mDEBUG\033[0m: Loading and preprocessing image from {temp_image_path}")
174
+ image_tensor = torchvision.io.read_image(temp_image_path)
175
+ print(f"\033[92mDEBUG\033[0m: Loaded image shape: {image_tensor.shape}")
176
+ image_tensor = image_tensor.float()
177
+ processed_image = process_image_data(image_tensor)
178
+ print(f"\033[92mDEBUG\033[0m: Processed image shape: {processed_image.shape if processed_image is not None else None}")
179
+
180
+ # Add batch dimension for inference and move to device
181
+ if mfcc is not None:
182
+ mfcc = mfcc.unsqueeze(0).to(device)
183
+ print(f"\033[92mDEBUG\033[0m: Final MFCC shape with batch dimension: {mfcc.shape}")
184
+
185
+ if processed_image is not None:
186
+ processed_image = processed_image.unsqueeze(0).to(device)
187
+ print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}")
188
+
189
+ # Run inference
190
+ print(f"\033[92mDEBUG\033[0m: Running inference on device: {device}")
191
+ if mfcc is not None and processed_image is not None:
192
+ with torch.no_grad():
193
+ sweetness = model(mfcc, processed_image)
194
+ print(f"\033[92mDEBUG\033[0m: Prediction successful: {sweetness.item()}")
195
+ else:
196
+ return "Error: Failed to process inputs. Please check the debug logs."
197
+
198
+ # Format the result
199
+ if sweetness is not None:
200
+ result = f"Predicted Sweetness: {sweetness.item():.2f}/13"
201
+
202
+ # Add a qualitative description
203
+ if sweetness.item() < 9:
204
+ result += "\n\nThis watermelon is not very sweet. You might want to choose another one."
205
+ elif sweetness.item() < 10:
206
+ result += "\n\nThis watermelon has moderate sweetness."
207
+ elif sweetness.item() < 11:
208
+ result += "\n\nThis watermelon is sweet! A good choice."
209
+ else:
210
+ result += "\n\nThis watermelon is very sweet! Excellent choice!"
211
+
212
+ return result
213
+ else:
214
+ return "Error: Could not predict sweetness. Please try again with different inputs."
215
+
216
+ except Exception as e:
217
+ import traceback
218
+ error_msg = f"Error: {str(e)}\n\n"
219
+ error_msg += traceback.format_exc()
220
+ print(f"\033[91mERR!\033[0m: {error_msg}")
221
+ return error_msg
222
+
223
+ # Apply GPU decorator if available in Gradio Spaces environment
224
+ if HAS_SPACES:
225
+ predict_sweetness_gpu = spaces.GPU(predict_sweetness)
226
+ print("\033[92mINFO\033[0m: GPU optimization enabled for prediction function")
227
+ else:
228
+ predict_sweetness_gpu = predict_sweetness
229
+
230
+ def create_app(model_path):
231
+ """Create and launch the Gradio interface"""
232
+ # Define the prediction function with model path
233
+ def predict_fn(audio, image):
234
+ if HAS_SPACES:
235
+ # Use GPU-optimized function if available
236
+ return predict_sweetness_gpu(audio, image, model_path)
237
+ else:
238
+ # Use regular function otherwise
239
+ return predict_sweetness(audio, image, model_path)
240
+
241
+ # Create Gradio interface
242
+ with gr.Blocks(title="Watermelon Sweetness Predictor", theme=gr.themes.Soft()) as interface:
243
+ gr.Markdown("# 🍉 Watermelon Sweetness Predictor")
244
+ gr.Markdown("""
245
+ This app predicts the sweetness of a watermelon based on its sound and appearance.
246
+
247
+ ## Instructions:
248
+ 1. Upload or record an audio of tapping the watermelon
249
+ 2. Upload or capture an image of the watermelon
250
+ 3. Click 'Predict' to get the sweetness estimation
251
+ """)
252
+
253
+ with gr.Row():
254
+ with gr.Column():
255
+ audio_input = gr.Audio(label="Upload or Record Audio", type="numpy")
256
+ image_input = gr.Image(label="Upload or Capture Image")
257
+ submit_btn = gr.Button("Predict Sweetness", variant="primary")
258
+
259
+ with gr.Column():
260
+ output = gr.Textbox(label="Prediction Results", lines=6)
261
+
262
+ submit_btn.click(
263
+ fn=predict_fn,
264
+ inputs=[audio_input, image_input],
265
+ outputs=output
266
+ )
267
+
268
+ gr.Markdown("""
269
+ ## How it works
270
+
271
+ The app uses a deep learning model that combines:
272
+ - Audio analysis using MFCC features and LSTM neural network
273
+ - Image analysis using ResNet-50 convolutional neural network
274
+
275
+ The model was trained on a dataset of watermelons with known sweetness values.
276
+
277
+ ## Tips for best results
278
+ - For audio: Tap the watermelon with your knuckle and record the sound
279
+ - For image: Take a clear photo of the whole watermelon in good lighting
280
+ """)
281
+
282
+ return interface
283
+
284
+ if __name__ == "__main__":
285
+ import argparse
286
+
287
+ parser = argparse.ArgumentParser(description="Watermelon Sweetness Prediction App")
288
+ parser.add_argument(
289
+ "--model_path",
290
+ type=str,
291
+ default="models/watermelon_model_final.pt",
292
+ help="Path to the trained model file"
293
+ )
294
+ parser.add_argument(
295
+ "--share",
296
+ action="store_true",
297
+ help="Create a shareable link for the app"
298
+ )
299
+ parser.add_argument(
300
+ "--debug",
301
+ action="store_true",
302
+ help="Enable verbose debug output"
303
+ )
304
+
305
+ args = parser.parse_args()
306
+
307
+ if args.debug:
308
+ print(f"\033[92mINFO\033[0m: Debug mode enabled")
309
+
310
+ # Check if model exists
311
+ if not os.path.exists(args.model_path):
312
+ print(f"\033[91mERR!\033[0m: Model not found at {args.model_path}")
313
+ print("\033[92mINFO\033[0m: Please train a model first or provide a valid model path")
314
+ sys.exit(1)
315
+
316
+ # Create and launch the app
317
+ app = create_app(args.model_path)
318
+ app.launch(share=args.share)
app_local_backup.py ADDED
@@ -0,0 +1,290 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+ import numpy as np
5
+ import gradio as gr
6
+ import torchaudio
7
+ import torchvision
8
+
9
+ # Add parent directory to path to import preprocess functions
10
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
11
+
12
+ # Import functions from infer_watermelon.py
13
+ from infer_watermelon import load_model
14
+
15
+ # Modified version of process_audio_data specifically for the app to handle various tensor shapes
16
+ def app_process_audio_data(waveform, sample_rate):
17
+ """Modified version of process_audio_data for the app that handles different tensor dimensions"""
18
+ try:
19
+ print(f"\033[92mDEBUG\033[0m: Processing audio - Initial shape: {waveform.shape}, Sample rate: {sample_rate}")
20
+
21
+ # Handle different tensor dimensions
22
+ if waveform.dim() == 3:
23
+ print(f"\033[92mDEBUG\033[0m: Found 3D tensor, converting to 2D")
24
+ # For 3D tensor, take the first item (batch dimension)
25
+ waveform = waveform[0]
26
+
27
+ if waveform.dim() == 2:
28
+ # Use the first channel for stereo audio
29
+ waveform = waveform[0]
30
+ print(f"\033[92mDEBUG\033[0m: Using first channel, new shape: {waveform.shape}")
31
+
32
+ # Resample to 16kHz if needed
33
+ resample_rate = 16000
34
+ if sample_rate != resample_rate:
35
+ print(f"\033[92mDEBUG\033[0m: Resampling from {sample_rate}Hz to {resample_rate}Hz")
36
+ waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=resample_rate)(waveform)
37
+
38
+ # Ensure 3 seconds of audio
39
+ if waveform.size(0) < 3 * resample_rate:
40
+ print(f"\033[92mDEBUG\033[0m: Padding audio from {waveform.size(0)} to {3 * resample_rate} samples")
41
+ waveform = torch.nn.functional.pad(waveform, (0, 3 * resample_rate - waveform.size(0)))
42
+ else:
43
+ print(f"\033[92mDEBUG\033[0m: Trimming audio from {waveform.size(0)} to {3 * resample_rate} samples")
44
+ waveform = waveform[: 3 * resample_rate]
45
+
46
+ # Apply MFCC transformation
47
+ print(f"\033[92mDEBUG\033[0m: Applying MFCC transformation")
48
+ mfcc_transform = torchaudio.transforms.MFCC(
49
+ sample_rate=resample_rate,
50
+ n_mfcc=13,
51
+ melkwargs={
52
+ "n_fft": 256,
53
+ "win_length": 256,
54
+ "hop_length": 128,
55
+ "n_mels": 40,
56
+ }
57
+ )
58
+
59
+ mfcc = mfcc_transform(waveform)
60
+ print(f"\033[92mDEBUG\033[0m: MFCC output shape: {mfcc.shape}")
61
+
62
+ return mfcc
63
+ except Exception as e:
64
+ import traceback
65
+ print(f"\033[91mERR!\033[0m: Error in audio processing: {e}")
66
+ print(traceback.format_exc())
67
+ return None
68
+
69
+ # Similarly for images, but let's import the original one
70
+ from preprocess import process_image_data
71
+
72
+ def init_model(model_path):
73
+ """Initialize the model for inference"""
74
+ model, device = load_model(model_path)
75
+ return model, device
76
+
77
+ def predict_sweetness(audio, image, model, device):
78
+ """Predict sweetness of a watermelon from audio and image input"""
79
+ try:
80
+ # Debug information about input types
81
+ print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
82
+ print(f"\033[92mDEBUG\033[0m: Audio input shape/length: {len(audio)}")
83
+ print(f"\033[92mDEBUG\033[0m: Image input type: {type(image)}")
84
+ if isinstance(image, np.ndarray):
85
+ print(f"\033[92mDEBUG\033[0m: Image input shape: {image.shape}")
86
+
87
+ # Handle different audio input formats
88
+ if isinstance(audio, tuple) and len(audio) == 2:
89
+ # Standard Gradio format: (sample_rate, audio_data)
90
+ sample_rate, audio_data = audio
91
+ print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
92
+ print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
93
+ elif isinstance(audio, tuple) and len(audio) > 2:
94
+ # Sometimes Gradio returns (sample_rate, audio_data, other_info...)
95
+ sample_rate, audio_data = audio[0], audio[-1]
96
+ print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
97
+ print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
98
+ elif isinstance(audio, str):
99
+ # Direct path to audio file
100
+ import torchaudio
101
+ audio_data, sample_rate = torchaudio.load(audio)
102
+ print(f"\033[92mDEBUG\033[0m: Loaded audio from path with shape: {audio_data.shape}")
103
+ else:
104
+ return f"Error: Unsupported audio format. Got {type(audio)}"
105
+
106
+ # Create a temporary file path for the audio and image
107
+ temp_dir = "temp"
108
+ os.makedirs(temp_dir, exist_ok=True)
109
+
110
+ temp_audio_path = os.path.join(temp_dir, "temp_audio.wav")
111
+ temp_image_path = os.path.join(temp_dir, "temp_image.jpg")
112
+
113
+ # Import necessary libraries
114
+ import torchaudio
115
+ import torchvision
116
+ import torchvision.transforms.functional as F
117
+ from PIL import Image
118
+
119
+ # Audio handling - direct processing from the data in memory
120
+ if isinstance(audio_data, np.ndarray):
121
+ # Convert numpy array to tensor
122
+ print(f"\033[92mDEBUG\033[0m: Converting numpy audio with shape {audio_data.shape} to tensor")
123
+ audio_tensor = torch.tensor(audio_data).float()
124
+
125
+ # Handle different audio dimensions
126
+ if audio_data.ndim == 1:
127
+ # Single channel audio
128
+ audio_tensor = audio_tensor.unsqueeze(0)
129
+ elif audio_data.ndim == 2:
130
+ # Ensure channels are first dimension
131
+ if audio_data.shape[0] > audio_data.shape[1]:
132
+ # More rows than columns, probably (samples, channels)
133
+ audio_tensor = torch.tensor(audio_data.T).float()
134
+ else:
135
+ # Already a tensor
136
+ audio_tensor = audio_data.float()
137
+
138
+ print(f"\033[92mDEBUG\033[0m: Audio tensor shape before processing: {audio_tensor.shape}")
139
+
140
+ # Skip saving/loading and process directly
141
+ mfcc = app_process_audio_data(audio_tensor, sample_rate)
142
+ print(f"\033[92mDEBUG\033[0m: MFCC tensor shape after processing: {mfcc.shape if mfcc is not None else None}")
143
+
144
+ # Image handling
145
+ if isinstance(image, np.ndarray):
146
+ print(f"\033[92mDEBUG\033[0m: Converting numpy image with shape {image.shape} to PIL")
147
+ pil_image = Image.fromarray(image)
148
+ pil_image.save(temp_image_path)
149
+ print(f"\033[92mDEBUG\033[0m: Saved image to {temp_image_path}")
150
+ elif isinstance(image, str):
151
+ # If image is already a path
152
+ temp_image_path = image
153
+ print(f"\033[92mDEBUG\033[0m: Using provided image path: {temp_image_path}")
154
+ else:
155
+ return f"Error: Unsupported image format. Got {type(image)}"
156
+
157
+ # Process image
158
+ print(f"\033[92mDEBUG\033[0m: Loading and preprocessing image from {temp_image_path}")
159
+ image_tensor = torchvision.io.read_image(temp_image_path)
160
+ print(f"\033[92mDEBUG\033[0m: Loaded image shape: {image_tensor.shape}")
161
+ image_tensor = image_tensor.float()
162
+ processed_image = process_image_data(image_tensor)
163
+ print(f"\033[92mDEBUG\033[0m: Processed image shape: {processed_image.shape if processed_image is not None else None}")
164
+
165
+ # Add batch dimension for inference
166
+ if mfcc is not None:
167
+ mfcc = mfcc.unsqueeze(0).to(device)
168
+ print(f"\033[92mDEBUG\033[0m: Final MFCC shape with batch dimension: {mfcc.shape}")
169
+
170
+ if processed_image is not None:
171
+ processed_image = processed_image.unsqueeze(0).to(device)
172
+ print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}")
173
+
174
+ # Run inference
175
+ print(f"\033[92mDEBUG\033[0m: Running inference")
176
+ if mfcc is not None and processed_image is not None:
177
+ with torch.no_grad():
178
+ sweetness = model(mfcc, processed_image)
179
+ print(f"\033[92mDEBUG\033[0m: Prediction successful: {sweetness.item()}")
180
+ else:
181
+ return "Error: Failed to process inputs. Please check the debug logs."
182
+
183
+ # Format the result
184
+ if sweetness is not None:
185
+ result = f"Predicted Sweetness: {sweetness.item():.2f}/13"
186
+
187
+ # Add a qualitative description
188
+ if sweetness.item() < 9:
189
+ result += "\n\nThis watermelon is not very sweet. You might want to choose another one."
190
+ elif sweetness.item() < 10:
191
+ result += "\n\nThis watermelon has moderate sweetness."
192
+ elif sweetness.item() < 11:
193
+ result += "\n\nThis watermelon is sweet! A good choice."
194
+ else:
195
+ result += "\n\nThis watermelon is very sweet! Excellent choice!"
196
+
197
+ return result
198
+ else:
199
+ return "Error: Could not predict sweetness. Please try again with different inputs."
200
+
201
+ except Exception as e:
202
+ import traceback
203
+ error_msg = f"Error: {str(e)}\n\n"
204
+ error_msg += traceback.format_exc()
205
+ print(f"\033[91mERR!\033[0m: {error_msg}")
206
+ return error_msg
207
+
208
+ def create_app(model_path):
209
+ """Create and launch the Gradio interface"""
210
+ # Initialize model
211
+ model, device = init_model(model_path)
212
+
213
+ # Define the prediction function with model and device
214
+ def predict_fn(audio, image):
215
+ return predict_sweetness(audio, image, model, device)
216
+
217
+ # Create Gradio interface
218
+ with gr.Blocks(title="Watermelon Sweetness Predictor") as interface:
219
+ gr.Markdown("# 🍉 Watermelon Sweetness Predictor")
220
+ gr.Markdown("""
221
+ This app predicts the sweetness of a watermelon based on its sound and appearance.
222
+
223
+ ## Instructions:
224
+ 1. Upload or record an audio of tapping the watermelon
225
+ 2. Upload or capture an image of the watermelon
226
+ 3. Click 'Submit' to get the predicted sweetness
227
+ """)
228
+
229
+ with gr.Row():
230
+ with gr.Column():
231
+ audio_input = gr.Audio(label="Upload or Record Audio", type="numpy")
232
+ image_input = gr.Image(label="Upload or Capture Image")
233
+ submit_btn = gr.Button("Predict Sweetness", variant="primary")
234
+
235
+ with gr.Column():
236
+ output = gr.Textbox(label="Prediction Results", lines=6)
237
+
238
+ submit_btn.click(
239
+ fn=predict_fn,
240
+ inputs=[audio_input, image_input],
241
+ outputs=output
242
+ )
243
+
244
+ gr.Markdown("""
245
+ ## How it works
246
+
247
+ The app uses a deep learning model that combines:
248
+ - Audio analysis using MFCC features and LSTM neural network
249
+ - Image analysis using ResNet-50 convolutional neural network
250
+
251
+ The model was trained on a dataset of watermelons with known sweetness values.
252
+ """)
253
+
254
+ return interface
255
+
256
+ if __name__ == "__main__":
257
+ import argparse
258
+
259
+ parser = argparse.ArgumentParser(description="Watermelon Sweetness Prediction App")
260
+ parser.add_argument(
261
+ "--model_path",
262
+ type=str,
263
+ default="models/watermelon_model_final.pt",
264
+ help="Path to the trained model file"
265
+ )
266
+ parser.add_argument(
267
+ "--share",
268
+ action="store_true",
269
+ help="Create a shareable link for the app"
270
+ )
271
+ parser.add_argument(
272
+ "--debug",
273
+ action="store_true",
274
+ help="Enable verbose debug output"
275
+ )
276
+
277
+ args = parser.parse_args()
278
+
279
+ if args.debug:
280
+ print(f"\033[92mINFO\033[0m: Debug mode enabled")
281
+
282
+ # Check if model exists
283
+ if not os.path.exists(args.model_path):
284
+ print(f"\033[91mERR!\033[0m: Model not found at {args.model_path}")
285
+ print("\033[92mINFO\033[0m: Please train a model first or provide a valid model path")
286
+ sys.exit(1)
287
+
288
+ # Create and launch the app
289
+ app = create_app(args.model_path)
290
+ app.launch(share=args.share)
infer_watermelon.py ADDED
@@ -0,0 +1,150 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+ import torchaudio
5
+ import torchvision
6
+ import argparse
7
+ import numpy as np
8
+
9
+ # Add parent directory to path to import the preprocess functions
10
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
11
+ from preprocess import process_audio_data, process_image_data
12
+
13
+ # Import the model definition
14
+ from train_watermelon import WatermelonModel
15
+
16
+ def load_model(model_path):
17
+ """Load a trained model from the given path"""
18
+ device = torch.device(
19
+ "cuda" if torch.cuda.is_available()
20
+ else "mps" if torch.backends.mps.is_available()
21
+ else "cpu"
22
+ )
23
+ print(f"\033[92mINFO\033[0m: Using device: {device}")
24
+
25
+ model = WatermelonModel().to(device)
26
+ model.load_state_dict(torch.load(model_path, map_location=device))
27
+ model.eval()
28
+ print(f"\033[92mINFO\033[0m: Loaded model from {model_path}")
29
+
30
+ return model, device
31
+
32
+ def infer_single_sample(audio_path, image_path, model, device):
33
+ """Run inference on a single sample"""
34
+ # Load and process audio
35
+ try:
36
+ waveform, sample_rate = torchaudio.load(audio_path)
37
+ mfcc = process_audio_data(waveform, sample_rate).to(device)
38
+
39
+ # Load and process image
40
+ image = torchvision.io.read_image(image_path)
41
+ image = image.float()
42
+ processed_image = process_image_data(image).to(device)
43
+
44
+ # Add batch dimension
45
+ mfcc = mfcc.unsqueeze(0)
46
+ processed_image = processed_image.unsqueeze(0)
47
+
48
+ # Run inference
49
+ with torch.no_grad():
50
+ sweetness = model(mfcc, processed_image)
51
+
52
+ return sweetness.item()
53
+ except Exception as e:
54
+ print(f"\033[91mERR!\033[0m: Error in inference: {e}")
55
+ return None
56
+
57
+ def infer_from_directory(data_dir, model_path, output_file=None, num_samples=None):
58
+ """Run inference on samples from the dataset directory"""
59
+ # Load model
60
+ model, device = load_model(model_path)
61
+
62
+ # Collect all samples
63
+ samples = []
64
+ results = []
65
+
66
+ print(f"\033[92mINFO\033[0m: Reading samples from {data_dir}")
67
+
68
+ # Walk through the directory structure
69
+ for sweetness_dir in os.listdir(data_dir):
70
+ try:
71
+ sweetness = float(sweetness_dir)
72
+ sweetness_path = os.path.join(data_dir, sweetness_dir)
73
+
74
+ if os.path.isdir(sweetness_path):
75
+ for id_dir in os.listdir(sweetness_path):
76
+ id_path = os.path.join(sweetness_path, id_dir)
77
+
78
+ if os.path.isdir(id_path):
79
+ audio_file = os.path.join(id_path, f"{id_dir}.wav")
80
+ image_file = os.path.join(id_path, f"{id_dir}.jpg")
81
+
82
+ if os.path.exists(audio_file) and os.path.exists(image_file):
83
+ samples.append((audio_file, image_file, sweetness, id_dir))
84
+ except ValueError:
85
+ # Skip directories that are not valid sweetness values
86
+ continue
87
+
88
+ # Limit the number of samples if specified
89
+ if num_samples is not None and num_samples > 0:
90
+ samples = samples[:num_samples]
91
+
92
+ print(f"\033[92mINFO\033[0m: Running inference on {len(samples)} samples")
93
+
94
+ # Run inference on each sample
95
+ for i, (audio_file, image_file, true_sweetness, sample_id) in enumerate(samples):
96
+ print(f"\033[92mINFO\033[0m: Processing sample {i+1}/{len(samples)}: {sample_id}")
97
+
98
+ predicted_sweetness = infer_single_sample(audio_file, image_file, model, device)
99
+
100
+ if predicted_sweetness is not None:
101
+ error = abs(predicted_sweetness - true_sweetness)
102
+ results.append({
103
+ 'sample_id': sample_id,
104
+ 'true_sweetness': true_sweetness,
105
+ 'predicted_sweetness': predicted_sweetness,
106
+ 'error': error
107
+ })
108
+ print(f" Sample ID: {sample_id}")
109
+ print(f" True sweetness: {true_sweetness:.2f}")
110
+ print(f" Predicted sweetness: {predicted_sweetness:.2f}")
111
+ print(f" Error: {error:.2f}")
112
+
113
+ # Calculate mean absolute error
114
+ if results:
115
+ mae = np.mean([result['error'] for result in results])
116
+ print(f"\033[92mINFO\033[0m: Mean Absolute Error: {mae:.4f}")
117
+
118
+ # Save results to file if specified
119
+ if output_file and results:
120
+ with open(output_file, 'w') as f:
121
+ f.write("sample_id,true_sweetness,predicted_sweetness,error\n")
122
+ for result in results:
123
+ f.write(f"{result['sample_id']},{result['true_sweetness']:.2f},{result['predicted_sweetness']:.2f},{result['error']:.2f}\n")
124
+ print(f"\033[92mINFO\033[0m: Results saved to {output_file}")
125
+
126
+ return results
127
+
128
+ def main():
129
+ parser = argparse.ArgumentParser(description="Watermelon Sweetness Inference")
130
+ parser.add_argument("--model_path", type=str, required=True, help="Path to the trained model file")
131
+ parser.add_argument("--data_dir", type=str, default="../cleaned", help="Path to the cleaned dataset directory")
132
+ parser.add_argument("--output_file", type=str, help="Path to save inference results (CSV)")
133
+ parser.add_argument("--num_samples", type=int, help="Number of samples to run inference on (default: all)")
134
+ parser.add_argument("--audio_path", type=str, help="Path to a single audio file for inference")
135
+ parser.add_argument("--image_path", type=str, help="Path to a single image file for inference")
136
+
137
+ args = parser.parse_args()
138
+
139
+ # Check if single sample inference or dataset inference
140
+ if args.audio_path and args.image_path:
141
+ # Single sample inference
142
+ model, device = load_model(args.model_path)
143
+ sweetness = infer_single_sample(args.audio_path, args.image_path, model, device)
144
+ print(f"Predicted sweetness: {sweetness:.2f}")
145
+ else:
146
+ # Dataset inference
147
+ infer_from_directory(args.data_dir, args.model_path, args.output_file, args.num_samples)
148
+
149
+ if __name__ == "__main__":
150
+ main()
models/model_1_20250406-064126.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5df632222fa87e09e635f90e5cce14bdd9fd34b442bf18daaf13e54dedfed132
+size 96095572
models/model_1_20250406-064635.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:02999bd33592de717dc1ec8054dc570193074c3f25a7283b3daa580b727b7134
+size 96095572
models/model_2_20250406-065053.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:80f999a1540c42ed74491692aa66c3b5a6171f972bdf47c9d52556fe1673c8dd
+size 96095572
models/watermelon_model_final.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:086780aee9897ea51a6b0da0fed8aaa61ae97563c70a8c6577849ef9a0220edb
+size 96095241
requirements.txt ADDED
@@ -0,0 +1,6 @@
+torch>=2.0.0
+torchaudio>=2.0.0
+torchvision>=0.15.0
+gradio>=3.50.0
+numpy>=1.20.0
+pillow>=9.0.0
runs/events.out.tfevents.1743920786.vm-jinzq.2059144.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f3e44b329373e1b4e8233833c35e382cf1c548c03a449e237c89b4c0333af42f
+size 88
runs/events.out.tfevents.1743920828.vm-jinzq.2059396.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1396659d9fdb300ed3bf8ee38bf6605c634376c36a3e47e8398968eb9ea4b6ea
+size 88
runs/events.out.tfevents.1743921401.jzqdebug-c245a8-job-84fn7.812.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d4346aad295036526c9dadae4a84f18cab863a1ec43f13b0d5b32566b5361179
+size 14985
runs/events.out.tfevents.1743921735.jzqdebug-c245a8-job-84fn7.1262.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6ec7e16dece5b3f09408359a3a18fb40a87f23e02e1b16981ebb9ea9e463f6ef
+size 7238
temp/temp_audio.wav ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8da44f18fa66bb5db09dc6ef4ea542c5274d8b2a1d952efd1db1ceec7948ca44
+size 1058488
temp/temp_image.jpg ADDED

Git LFS Details

  • SHA256: 88a3633370f2a04e0c41946cdcd6f63883eca31ae8534b8f4379d6e8b84a25f0
  • Pointer size: 131 Bytes
  • Size of remote file: 406 kB
train_watermelon.py ADDED
@@ -0,0 +1,261 @@
1
+ import os
2
+ import time
3
+ import torch
4
+ import torchaudio
5
+ import torchvision
6
+ import numpy as np
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from torch.utils.tensorboard import SummaryWriter
9
+ import sys
10
+
11
+ # Add parent directory to path to import the preprocess functions
12
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
13
+ from preprocess import process_audio_data, process_image_data
14
+
15
+ # Print library versions
16
+ print(f"\033[92mINFO\033[0m: PyTorch version: {torch.__version__}")
17
+ print(f"\033[92mINFO\033[0m: Torchaudio version: {torchaudio.__version__}")
18
+ print(f"\033[92mINFO\033[0m: Torchvision version: {torchvision.__version__}")
19
+
20
+ # Device selection
21
+ device = torch.device(
22
+ "cuda"
23
+ if torch.cuda.is_available()
24
+ else "mps" if torch.backends.mps.is_available() else "cpu"
25
+ )
26
+ print(f"\033[92mINFO\033[0m: Using device: {device}")
27
+
28
+ # Hyperparameters
29
+ batch_size = 16
30
+ epochs = 2
31
+ learning_rate = 0.0001
32
+
33
+ # Model save directory
34
+ os.makedirs("models/", exist_ok=True)
35
+
36
+
37
+ class WatermelonDataset(Dataset):
38
+ def __init__(self, data_dir):
39
+ self.data_dir = data_dir
40
+ self.samples = []
41
+
42
+ # Walk through the directory structure
43
+ for sweetness_dir in os.listdir(data_dir):
44
+ sweetness = float(sweetness_dir)
45
+ sweetness_path = os.path.join(data_dir, sweetness_dir)
46
+
47
+ if os.path.isdir(sweetness_path):
48
+ for id_dir in os.listdir(sweetness_path):
49
+ id_path = os.path.join(sweetness_path, id_dir)
50
+
51
+ if os.path.isdir(id_path):
52
+ audio_file = os.path.join(id_path, f"{id_dir}.wav")
53
+ image_file = os.path.join(id_path, f"{id_dir}.jpg")
54
+
55
+ if os.path.exists(audio_file) and os.path.exists(image_file):
56
+ self.samples.append((audio_file, image_file, sweetness))
57
+
58
+ print(f"\033[92mINFO\033[0m: Loaded {len(self.samples)} samples from {data_dir}")
59
+
60
+ def __len__(self):
61
+ return len(self.samples)
62
+
63
+ def __getitem__(self, idx):
64
+ audio_path, image_path, label = self.samples[idx]
65
+
66
+ # Load and process audio
67
+ try:
68
+ waveform, sample_rate = torchaudio.load(audio_path)
69
+ mfcc = process_audio_data(waveform, sample_rate)
70
+
71
+ # Load and process image
72
+ image = torchvision.io.read_image(image_path)
73
+ image = image.float()
74
+ processed_image = process_image_data(image)
75
+
76
+ return mfcc, processed_image, torch.tensor(label).float()
77
+ except Exception as e:
78
+ print(f"\033[91mERR!\033[0m: Error processing sample {idx}: {e}")
79
+ # Return a fallback sample or skip this sample
80
+ # For simplicity, we'll return the first sample again
81
+ if idx == 0: # Prevent infinite recursion
82
+ raise e
83
+ return self.__getitem__(0)
84
+
85
+
86
+ class WatermelonModel(torch.nn.Module):
87
+ def __init__(self):
88
+ super(WatermelonModel, self).__init__()
89
+
90
+ # LSTM for audio features
91
+ self.lstm = torch.nn.LSTM(
92
+ input_size=376, hidden_size=64, num_layers=2, batch_first=True
93
+ )
94
+ self.lstm_fc = torch.nn.Linear(
95
+ 64, 128
96
+ ) # Convert LSTM output to 128-dim for merging
97
+
98
+ # ResNet50 for image features
99
+ self.resnet = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
100
+ self.resnet.fc = torch.nn.Linear(
101
+ self.resnet.fc.in_features, 128
102
+ ) # Convert ResNet output to 128-dim for merging
103
+
104
+ # Fully connected layers for final prediction
105
+ self.fc1 = torch.nn.Linear(256, 64)
106
+ self.fc2 = torch.nn.Linear(64, 1)
107
+ self.relu = torch.nn.ReLU()
108
+
109
+ def forward(self, mfcc, image):
110
+ # LSTM branch
111
+ lstm_output, _ = self.lstm(mfcc)
112
+ lstm_output = lstm_output[:, -1, :] # Use the output of the last time step
113
+ lstm_output = self.lstm_fc(lstm_output)
114
+
115
+ # ResNet branch
116
+ resnet_output = self.resnet(image)
117
+
118
+ # Concatenate LSTM and ResNet outputs
119
+ merged = torch.cat((lstm_output, resnet_output), dim=1)
120
+
121
+ # Fully connected layers
122
+ output = self.relu(self.fc1(merged))
123
+ output = self.fc2(output)
124
+
125
+ return output
126
+
127
+
128
+ def train_model(data_dir, output_dir="models/"):
129
+ # Create dataset
130
+ dataset = WatermelonDataset(data_dir)
131
+ n_samples = len(dataset)
132
+
133
+ # Split dataset
134
+ train_size = int(0.7 * n_samples)
135
+ val_size = int(0.2 * n_samples)
136
+ test_size = n_samples - train_size - val_size
137
+
138
+ train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
139
+ dataset, [train_size, val_size, test_size]
140
+ )
141
+
142
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
143
+ val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
144
+
145
+ # Initialize model
146
+ model = WatermelonModel().to(device)
147
+
148
+ # Loss function and optimizer
149
+ criterion = torch.nn.MSELoss()
150
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
151
+
152
+ # TensorBoard
153
+ writer = SummaryWriter("runs/")
154
+ global_step = 0
155
+
156
+ print(f"\033[92mINFO\033[0m: Training model for {epochs} epochs")
157
+ print(f"\033[92mINFO\033[0m: Training samples: {len(train_dataset)}")
158
+ print(f"\033[92mINFO\033[0m: Validation samples: {len(val_dataset)}")
159
+ print(f"\033[92mINFO\033[0m: Test samples: {len(test_dataset)}")
160
+ print(f"\033[92mINFO\033[0m: Batch size: {batch_size}")
161
+
162
+ # Training loop
163
+ for epoch in range(epochs):
164
+ print(f"\033[92mINFO\033[0m: Training epoch ({epoch+1}/{epochs})")
165
+
166
+ model.train()
167
+ running_loss = 0.0
168
+
169
+ for i, (mfcc, image, label) in enumerate(train_loader):
170
+ try:
171
+ mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
172
+
173
+ optimizer.zero_grad()
174
+ output = model(mfcc, image)
175
+ label = label.view(-1, 1).float()
176
+ loss = criterion(output, label)
177
+ loss.backward()
178
+ optimizer.step()
179
+
180
+ running_loss += loss.item()
181
+ writer.add_scalar("Training Loss", loss.item(), global_step)
182
+ global_step += 1
183
+
184
+ if i % 10 == 0:
185
+ print(f"\033[92mINFO\033[0m: Batch {i}/{len(train_loader)}, Loss: {loss.item():.4f}")
186
+
187
+ except Exception as e:
188
+ print(f"\033[91mERR!\033[0m: Error in training batch {i}: {e}")
189
+ continue
190
+
191
+ # Validation phase
192
+ model.eval()
193
+ val_loss = 0.0
194
+ with torch.no_grad():
195
+ for i, (mfcc, image, label) in enumerate(val_loader):
196
+ try:
197
+ mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
198
+ output = model(mfcc, image)
199
+ label = label.view(-1, 1).float()
200
+ loss = criterion(output, label)
201
+ val_loss += loss.item()
202
+ except Exception as e:
203
+ print(f"\033[91mERR!\033[0m: Error in validation batch {i}: {e}")
204
+ continue
205
+
206
+ avg_train_loss = running_loss / len(train_loader) if len(train_loader) > 0 else float('inf')
207
+ avg_val_loss = val_loss / len(val_loader) if len(val_loader) > 0 else float('inf')
208
+
209
+ # Record validation loss
210
+ writer.add_scalar("Validation Loss", avg_val_loss, epoch)
211
+
212
+ print(
213
+ f"Epoch [{epoch+1}/{epochs}], Training Loss: {avg_train_loss:.4f}, "
214
+ f"Validation Loss: {avg_val_loss:.4f}"
215
+ )
216
+
217
+ # Save model checkpoint
218
+ timestamp = time.strftime("%Y%m%d-%H%M%S")
219
+ model_path = os.path.join(output_dir, f"model_{epoch+1}_{timestamp}.pt")
220
+ torch.save(model.state_dict(), model_path)
221
+
222
+ print(
223
+ f"\033[92mINFO\033[0m: Model checkpoint epoch [{epoch+1}/{epochs}] saved: {model_path}"
224
+ )
225
+
226
+ # Save final model
227
+ final_model_path = os.path.join(output_dir, "watermelon_model_final.pt")
228
+ torch.save(model.state_dict(), final_model_path)
229
+ print(f"\033[92mINFO\033[0m: Final model saved: {final_model_path}")
230
+
231
+ print(f"\033[92mINFO\033[0m: Training complete")
232
+ return final_model_path
233
+
234
+
235
+ if __name__ == "__main__":
236
+ import argparse
237
+
238
+ parser = argparse.ArgumentParser(description="Train the Watermelon Sweetness Prediction Model")
239
+ parser.add_argument(
240
+ "--data_dir",
241
+ type=str,
242
+ default="../cleaned",
243
+ help="Path to the cleaned dataset directory"
244
+ )
245
+ parser.add_argument(
246
+ "--output_dir",
247
+ type=str,
248
+ default="models/",
249
+ help="Directory to save model checkpoints and the final model"
250
+ )
251
+
252
+ args = parser.parse_args()
253
+
254
+ # Ensure output directory exists
255
+ os.makedirs(args.output_dir, exist_ok=True)
256
+
257
+ # Train the model
258
+ final_model_path = train_model(args.data_dir, args.output_dir)
259
+
260
+ print(f"\033[92mINFO\033[0m: Training completed successfully!")
261
+ print(f"\033[92mINFO\033[0m: Final model saved at: {final_model_path}")