Xalphinions committed
Commit 34c4a97 · verified · 1 Parent(s): 6660140

Upload folder using huggingface_hub

Files changed (2):
  1. app.py +231 -211
  2. requirements.txt +1 -0
app.py CHANGED
@@ -2,15 +2,29 @@ import torch, torchaudio, torchvision
 import os
 import gradio as gr
 import numpy as np
+import traceback
 
 from preprocess import process_audio_data, process_image_data
 from train import WatermelonModel
 from infer import infer
 
+# Add HuggingFace Spaces GPU decorator
+try:
+    import spaces
+    use_gpu_decorator = True
+    print("\033[92mINFO\033[0m: HuggingFace Spaces GPU support detected")
+except ImportError:
+    use_gpu_decorator = False
+    print("\033[93mWARNING\033[0m: HuggingFace Spaces GPU support not detected, running in standard mode")
+
+# Global device variable
+device = None
+
+@spaces.GPU
 def load_model(model_path):
     global device
     device = torch.device(
-        "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+        "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
     )
     print(f"\033[92mINFO\033[0m: Using device: {device}")
 
@@ -39,231 +53,237 @@ def load_model(model_path):
         print(f"File size: {file_size} bytes")
         raise
 
-if __name__ == "__main__":
-    import argparse
-
-    parser = argparse.ArgumentParser(description="Watermelon sweetness predictor")
-    parser.add_argument("--model_path", type=str, default="./models/model_15_20250405-033557.pt", help="Path to the trained model")
-    args = parser.parse_args()
-
-    model = load_model(args.model_path)
-
-    def predict(audio, image):
-        try:
-            # Debug audio input
-            print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
-            print(f"\033[92mDEBUG\033[0m: Audio input value: {audio}")
-
-            # Handle different formats of audio input from Gradio
-            if audio is None:
-                return "Error: No audio provided. Please upload or record audio."
-
-            if isinstance(audio, tuple) and len(audio) >= 2:
-                sr, audio_data = audio[0], audio[-1]
-                print(f"\033[92mDEBUG\033[0m: Audio format: sr={sr}, audio_data shape={audio_data.shape if hasattr(audio_data, 'shape') else 'no shape'}")
-            elif isinstance(audio, tuple) and len(audio) == 1:
-                # Handle single element tuple
-                audio_data = audio[0]
-                sr = 44100  # Assume default sample rate
-                print(f"\033[92mDEBUG\033[0m: Single element audio tuple, using default sr={sr}")
-            elif isinstance(audio, np.ndarray):
-                # Handle direct numpy array
-                audio_data = audio
-                sr = 44100  # Assume default sample rate
-                print(f"\033[92mDEBUG\033[0m: Audio is numpy array, using default sr={sr}")
-            else:
-                return f"Error: Unexpected audio format: {type(audio)}"
-
-            # Ensure audio_data is correctly shaped
-            if isinstance(audio_data, np.ndarray):
-                # Make sure we have a 2D array
-                if len(audio_data.shape) == 1:
-                    audio_data = np.expand_dims(audio_data, axis=0)
-                    print(f"\033[92mDEBUG\033[0m: Reshaped 1D audio to 2D: {audio_data.shape}")
-
-                # If channels are the second dimension, transpose
-                if len(audio_data.shape) == 2 and audio_data.shape[0] > audio_data.shape[1]:
-                    audio_data = np.transpose(audio_data)
-                    print(f"\033[92mDEBUG\033[0m: Transposed audio shape to: {audio_data.shape}")
-
-            # Convert to tensor
-            audio_tensor = torch.tensor(audio_data).float()
-            print(f"\033[92mDEBUG\033[0m: Audio tensor shape: {audio_tensor.shape}")
-
-            # Process audio data and handle None case
-            mfcc = process_audio_data(audio_tensor, sr)
-            if mfcc is None:
-                return "Error: Failed to process audio data. Make sure your audio contains a clear tapping sound."
-
-            mfcc = mfcc.to(device)
-            print(f"\033[92mDEBUG\033[0m: MFCC shape: {mfcc.shape}")
-
-            # Debug image input
-            print(f"\033[92mDEBUG\033[0m: Image input type: {type(image)}")
-            print(f"\033[92mDEBUG\033[0m: Image shape: {image.shape if hasattr(image, 'shape') else 'No shape'}")
-
-            # Process image data and handle None case
-            if image is None:
-                return "Error: No image provided. Please upload an image."
-
-            # Handle different image formats
-            if isinstance(image, np.ndarray):
-                # Check if image is properly formatted (H, W, C) with 3 channels
-                if len(image.shape) == 3 and image.shape[2] == 3:
-                    # Convert to tensor with shape (C, H, W) as expected by PyTorch
-                    img = torch.tensor(image).float().permute(2, 0, 1)
-                    print(f"\033[92mDEBUG\033[0m: Converted image to tensor with shape: {img.shape}")
-                elif len(image.shape) == 2:
-                    # Grayscale image, expand to 3 channels
-                    img = torch.tensor(image).float().unsqueeze(0).repeat(3, 1, 1)
-                    print(f"\033[92mDEBUG\033[0m: Converted grayscale image to RGB tensor with shape: {img.shape}")
-                else:
-                    return f"Error: Unexpected image shape: {image.shape}. Expected RGB or grayscale image."
-            else:
-                return f"Error: Unexpected image format: {type(image)}. Expected numpy array."
-
-            # Scale pixel values to [0, 1] if needed
-            if img.max() > 1.0:
-                img = img / 255.0
-                print(f"\033[92mDEBUG\033[0m: Scaled image pixel values to range [0, 1]")
-
-            # Get image dimensions and check if they're reasonable
-            print(f"\033[92mDEBUG\033[0m: Final image tensor shape before processing: {img.shape}")
-
-            # Process image
-            try:
-                img_processed = process_image_data(img)
-                if img_processed is None:
-                    return "Error: Failed to process image data. Make sure your image clearly shows a watermelon."
-
-                img_processed = img_processed.to(device)
-                print(f"\033[92mDEBUG\033[0m: Processed image shape: {img_processed.shape}")
-            except Exception as e:
-                print(f"\033[91mERROR\033[0m: Image processing error: {str(e)}")
-                return f"Error in image processing: {str(e)}"
-
-            # Run inference
-            try:
-                # Based on the error, it seems infer() expects file paths, not tensors
-                # Let's create temporary files for the processed data
-                temp_dir = os.path.join(os.getcwd(), "temp")
-                os.makedirs(temp_dir, exist_ok=True)
-
-                # Save the audio to a temporary file if infer expects a file path
-                temp_audio_path = os.path.join(temp_dir, "temp_audio.wav")
-                if not isinstance(audio, str) and isinstance(audio, tuple) and len(audio) >= 2:
-                    # If we have the original audio data and sample rate
-                    audio_array = audio[-1]
-                    sr = audio[0]
-
-                    # Check if the audio array is valid
-                    if audio_array.size == 0:
-                        return "Error: Audio data is empty. Please record a longer audio clip."
-
-                    # Get the duration of the audio
-                    duration = audio_array.shape[-1] / sr
-                    print(f"\033[92mDEBUG\033[0m: Audio duration: {duration:.2f} seconds")
-
-                    # Check if we have at least 1 second of audio - but don't reject, just pad if needed
-                    min_duration = 1.0  # minimum 1 second of audio
-                    if duration < min_duration:
-                        print(f"\033[93mWARNING\033[0m: Audio is shorter than {min_duration} seconds. Padding will be applied.")
-                        # Calculate samples needed to reach minimum duration
-                        samples_needed = int(min_duration * sr) - audio_array.shape[-1]
-                        # Pad with zeros
-                        padding = np.zeros((audio_array.shape[0], samples_needed), dtype=audio_array.dtype)
-                        audio_array = np.concatenate([audio_array, padding], axis=1)
-                        print(f"\033[92mDEBUG\033[0m: Padded audio to shape: {audio_array.shape}")
-
-                    # Make sure audio has 2 dimensions
-                    if len(audio_array.shape) == 1:
-                        audio_array = np.expand_dims(audio_array, axis=0)
-
-                    print(f"\033[92mDEBUG\033[0m: Audio array shape before saving: {audio_array.shape}, sr: {sr}")
-
-                    # Make sure it's in the right format for torchaudio.save
-                    audio_tensor = torch.tensor(audio_array).float()
-                    if audio_tensor.dim() == 1:
-                        audio_tensor = audio_tensor.unsqueeze(0)
-
-                    torchaudio.save(temp_audio_path, audio_tensor, sr)
-                    print(f"\033[92mDEBUG\033[0m: Saved temporary audio file to {temp_audio_path}")
-
-                    # Let's also process the audio here to verify it works
-                    test_mfcc = process_audio_data(audio_tensor, sr)
-                    if test_mfcc is None:
-                        return "Error: Unable to process the audio. Please try recording a different audio sample."
-                    else:
-                        print(f"\033[92mDEBUG\033[0m: Audio pre-check passed. MFCC shape: {test_mfcc.shape}")
-
-                    audio_path = temp_audio_path
-                else:
-                    # If we don't have a valid path, return an error
-                    return "Error: Cannot process audio for inference. Invalid audio format."
-
-                # Save the image to a temporary file if infer expects a file path
-                temp_image_path = os.path.join(temp_dir, "temp_image.jpg")
-                if isinstance(image, np.ndarray):
-                    import cv2
-                    cv2.imwrite(temp_image_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
-                    print(f"\033[92mDEBUG\033[0m: Saved temporary image file to {temp_image_path}")
-                    image_path = temp_image_path
+# Define the main prediction function
+def predict_impl(audio, image, model):
+    try:
+        # Debug audio input
+        print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
+        print(f"\033[92mDEBUG\033[0m: Audio input value: {audio}")
+
+        # Handle different formats of audio input from Gradio
+        if audio is None:
+            return "Error: No audio provided. Please upload or record audio."
+
+        if isinstance(audio, tuple) and len(audio) >= 2:
+            sr, audio_data = audio[0], audio[-1]
+            print(f"\033[92mDEBUG\033[0m: Audio format: sr={sr}, audio_data shape={audio_data.shape if hasattr(audio_data, 'shape') else 'no shape'}")
+        elif isinstance(audio, tuple) and len(audio) == 1:
+            # Handle single element tuple
+            audio_data = audio[0]
+            sr = 44100  # Assume default sample rate
+            print(f"\033[92mDEBUG\033[0m: Single element audio tuple, using default sr={sr}")
+        elif isinstance(audio, np.ndarray):
+            # Handle direct numpy array
+            audio_data = audio
+            sr = 44100  # Assume default sample rate
+            print(f"\033[92mDEBUG\033[0m: Audio is numpy array, using default sr={sr}")
+        else:
+            return f"Error: Unexpected audio format: {type(audio)}"
+
+        # Ensure audio_data is correctly shaped
+        if isinstance(audio_data, np.ndarray):
+            # Make sure we have a 2D array
+            if len(audio_data.shape) == 1:
+                audio_data = np.expand_dims(audio_data, axis=0)
+                print(f"\033[92mDEBUG\033[0m: Reshaped 1D audio to 2D: {audio_data.shape}")
+
+            # If channels are the second dimension, transpose
+            if len(audio_data.shape) == 2 and audio_data.shape[0] > audio_data.shape[1]:
+                audio_data = np.transpose(audio_data)
+                print(f"\033[92mDEBUG\033[0m: Transposed audio shape to: {audio_data.shape}")
+
+        # Convert to tensor
+        audio_tensor = torch.tensor(audio_data).float()
+        print(f"\033[92mDEBUG\033[0m: Audio tensor shape: {audio_tensor.shape}")
+
+        # Process audio data and handle None case
+        mfcc = process_audio_data(audio_tensor, sr)
+        if mfcc is None:
+            return "Error: Failed to process audio data. Make sure your audio contains a clear tapping sound."
+
+        mfcc = mfcc.to(device)
+        print(f"\033[92mDEBUG\033[0m: MFCC shape: {mfcc.shape}")
+
+        # Debug image input
+        print(f"\033[92mDEBUG\033[0m: Image input type: {type(image)}")
+        print(f"\033[92mDEBUG\033[0m: Image shape: {image.shape if hasattr(image, 'shape') else 'No shape'}")
+
+        # Process image data and handle None case
+        if image is None:
+            return "Error: No image provided. Please upload an image."
+
+        # Handle different image formats
+        if isinstance(image, np.ndarray):
+            # Check if image is properly formatted (H, W, C) with 3 channels
+            if len(image.shape) == 3 and image.shape[2] == 3:
+                # Convert to tensor with shape (C, H, W) as expected by PyTorch
+                img = torch.tensor(image).float().permute(2, 0, 1)
+                print(f"\033[92mDEBUG\033[0m: Converted image to tensor with shape: {img.shape}")
+            elif len(image.shape) == 2:
+                # Grayscale image, expand to 3 channels
+                img = torch.tensor(image).float().unsqueeze(0).repeat(3, 1, 1)
+                print(f"\033[92mDEBUG\033[0m: Converted grayscale image to RGB tensor with shape: {img.shape}")
+            else:
+                return f"Error: Unexpected image shape: {image.shape}. Expected RGB or grayscale image."
+        else:
+            return f"Error: Unexpected image format: {type(image)}. Expected numpy array."
+
+        # Scale pixel values to [0, 1] if needed
+        if img.max() > 1.0:
+            img = img / 255.0
+            print(f"\033[92mDEBUG\033[0m: Scaled image pixel values to range [0, 1]")
+
+        # Get image dimensions and check if they're reasonable
+        print(f"\033[92mDEBUG\033[0m: Final image tensor shape before processing: {img.shape}")
+
+        # Process image
+        try:
+            img_processed = process_image_data(img)
+            if img_processed is None:
+                return "Error: Failed to process image data. Make sure your image clearly shows a watermelon."
+
+            img_processed = img_processed.to(device)
+            print(f"\033[92mDEBUG\033[0m: Processed image shape: {img_processed.shape}")
+        except Exception as e:
+            print(f"\033[91mERROR\033[0m: Image processing error: {str(e)}")
+            return f"Error in image processing: {str(e)}"
+
+        # Run inference
+        try:
+            # Based on the error, it seems infer() expects file paths, not tensors
+            # Let's create temporary files for the processed data
+            temp_dir = os.path.join(os.getcwd(), "temp")
+            os.makedirs(temp_dir, exist_ok=True)
+
+            # Save the audio to a temporary file if infer expects a file path
+            temp_audio_path = os.path.join(temp_dir, "temp_audio.wav")
+            if not isinstance(audio, str) and isinstance(audio, tuple) and len(audio) >= 2:
+                # If we have the original audio data and sample rate
+                audio_array = audio[-1]
+                sr = audio[0]
+
+                # Check if the audio array is valid
+                if audio_array.size == 0:
+                    return "Error: Audio data is empty. Please record a longer audio clip."
+
+                # Get the duration of the audio
+                duration = audio_array.shape[-1] / sr
+                print(f"\033[92mDEBUG\033[0m: Audio duration: {duration:.2f} seconds")
+
+                # Check if we have at least 1 second of audio - but don't reject, just pad if needed
+                min_duration = 1.0  # minimum 1 second of audio
+                if duration < min_duration:
+                    print(f"\033[93mWARNING\033[0m: Audio is shorter than {min_duration} seconds. Padding will be applied.")
+                    # Calculate samples needed to reach minimum duration
+                    samples_needed = int(min_duration * sr) - audio_array.shape[-1]
+                    # Pad with zeros
+                    padding = np.zeros((audio_array.shape[0], samples_needed), dtype=audio_array.dtype)
+                    audio_array = np.concatenate([audio_array, padding], axis=1)
+                    print(f"\033[92mDEBUG\033[0m: Padded audio to shape: {audio_array.shape}")
+
+                # Make sure audio has 2 dimensions
+                if len(audio_array.shape) == 1:
+                    audio_array = np.expand_dims(audio_array, axis=0)
+
+                print(f"\033[92mDEBUG\033[0m: Audio array shape before saving: {audio_array.shape}, sr: {sr}")
+
+                # Make sure it's in the right format for torchaudio.save
+                audio_tensor = torch.tensor(audio_array).float()
+                if audio_tensor.dim() == 1:
+                    audio_tensor = audio_tensor.unsqueeze(0)
+
+                torchaudio.save(temp_audio_path, audio_tensor, sr)
+                print(f"\033[92mDEBUG\033[0m: Saved temporary audio file to {temp_audio_path}")
+
+                # Let's also process the audio here to verify it works
+                test_mfcc = process_audio_data(audio_tensor, sr)
+                if test_mfcc is None:
+                    return "Error: Unable to process the audio. Please try recording a different audio sample."
                 else:
-                    # If we don't have a valid image, return an error
-                    return "Error: Cannot process image for inference. Invalid image format."
+                    print(f"\033[92mDEBUG\033[0m: Audio pre-check passed. MFCC shape: {test_mfcc.shape}")
 
-                # Create a modified version of infer that handles None returns
-                def safe_infer(audio_path, image_path, model, device):
+                audio_path = temp_audio_path
+            else:
+                # If we don't have a valid path, return an error
+                return "Error: Cannot process audio for inference. Invalid audio format."
+
+            # Save the image to a temporary file if infer expects a file path
+            temp_image_path = os.path.join(temp_dir, "temp_image.jpg")
+            if isinstance(image, np.ndarray):
+                import cv2
+                cv2.imwrite(temp_image_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
+                print(f"\033[92mDEBUG\033[0m: Saved temporary image file to {temp_image_path}")
+                image_path = temp_image_path
+            else:
+                # If we don't have a valid image, return an error
+                return "Error: Cannot process image for inference. Invalid image format."
+
+            # Create a modified version of infer that handles None returns
+            def safe_infer(audio_path, image_path, model, device):
+                try:
+                    return infer(audio_path, image_path, model, device)
+                except Exception as e:
+                    print(f"\033[91mERROR\033[0m: Error in infer function: {str(e)}")
+                    # Try a more direct approach
                     try:
-                        return infer(audio_path, image_path, model, device)
-                    except Exception as e:
-                        print(f"\033[91mERROR\033[0m: Error in infer function: {str(e)}")
-                        # Try a more direct approach
-                        try:
-                            # Load audio and process
-                            audio, sr = torchaudio.load(audio_path)
-                            mfcc = process_audio_data(audio, sr)
-                            if mfcc is None:
-                                raise ValueError("Audio processing failed - MFCC is None")
-                            mfcc = mfcc.to(device)
-
-                            # Load image and process
-                            image = cv2.imread(image_path)
-                            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-                            image_tensor = torch.tensor(image).float().permute(2, 0, 1) / 255.0
-                            img_processed = process_image_data(image_tensor)
-                            if img_processed is None:
-                                raise ValueError("Image processing failed - processed image is None")
-                            img_processed = img_processed.to(device)
-
-                            # Run model inference
-                            with torch.no_grad():
-                                prediction = model(mfcc, img_processed)
-                            return prediction
-                        except Exception as e2:
-                            print(f"\033[91mERROR\033[0m: Fallback inference also failed: {str(e2)}")
-                            raise
-
-                # Call our safer version
-                print(f"\033[92mDEBUG\033[0m: Calling safe_infer with audio_path={audio_path}, image_path={image_path}")
-                sweetness = safe_infer(audio_path, image_path, model, device)
-                if sweetness is None:
-                    return "Error: The model was unable to make a prediction. Please try with different inputs."
-
-                print(f"\033[92mDEBUG\033[0m: Inference result: {sweetness.item()}")
-                return f"Predicted Sweetness: {sweetness.item():.2f}/10"
-            except Exception as e:
-                import traceback
-                print(f"\033[91mERROR\033[0m: Inference failed: {str(e)}")
-                print(f"\033[91mTraceback\033[0m: {traceback.format_exc()}")
-                return f"Error during inference: {str(e)}"
+-                        return infer(audio_path, image_path, model, device)
+-                    except Exception as e:
+-                        print(f"\033[91mERROR\033[0m: Error in infer function: {str(e)}")
+-                        # Try a more direct approach
+-                        try:
+-                            # Load audio and process
+-                            audio, sr = torchaudio.load(audio_path)
+-                            mfcc = process_audio_data(audio, sr)
+-                            if mfcc is None:
+-                                raise ValueError("Audio processing failed - MFCC is None")
+-                            mfcc = mfcc.to(device)
+-
+-                            # Load image and process
+-                            image = cv2.imread(image_path)
+-                            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+-                            image_tensor = torch.tensor(image).float().permute(2, 0, 1) / 255.0
+-                            img_processed = process_image_data(image_tensor)
+-                            if img_processed is None:
+-                                raise ValueError("Image processing failed - processed image is None")
+-                            img_processed = img_processed.to(device)
+-
+-                            # Run model inference
+-                            with torch.no_grad():
+-                                prediction = model(mfcc, img_processed)
+-                            return prediction
+-                        except Exception as e2:
+-                            print(f"\033[91mERROR\033[0m: Fallback inference also failed: {str(e2)}")
+-                            raise
+-
+-                # Call our safer version
+-                print(f"\033[92mDEBUG\033[0m: Calling safe_infer with audio_path={audio_path}, image_path={image_path}")
+-                sweetness = safe_infer(audio_path, image_path, model, device)
+-                if sweetness is None:
+-                    return "Error: The model was unable to make a prediction. Please try with different inputs."
+-
+-                print(f"\033[92mDEBUG\033[0m: Inference result: {sweetness.item()}")
+-                return f"Predicted Sweetness: {sweetness.item():.2f}/10"
+-            except Exception as e:
+-                import traceback
+-                print(f"\033[91mERROR\033[0m: Inference failed: {str(e)}")
+-                print(f"\033[91mTraceback\033[0m: {traceback.format_exc()}")
+-                return f"Error during inference: {str(e)}"
+                        # Load audio and process
+                        audio, sr = torchaudio.load(audio_path)
+                        mfcc = process_audio_data(audio, sr)
+                        if mfcc is None:
+                            raise ValueError("Audio processing failed - MFCC is None")
+                        mfcc = mfcc.to(device)
+
+                        # Load image and process
+                        image = cv2.imread(image_path)
+                        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+                        image_tensor = torch.tensor(image).float().permute(2, 0, 1) / 255.0
+                        img_processed = process_image_data(image_tensor)
+                        if img_processed is None:
+                            raise ValueError("Image processing failed - processed image is None")
+                        img_processed = img_processed.to(device)
+
+                        # Run model inference
+                        with torch.no_grad():
+                            prediction = model(mfcc, img_processed)
+                        return prediction
+                    except Exception as e2:
+                        print(f"\033[91mERROR\033[0m: Fallback inference also failed: {str(e2)}")
+                        raise
+
+            # Call our safer version
+            print(f"\033[92mDEBUG\033[0m: Calling safe_infer with audio_path={audio_path}, image_path={image_path}")
+            sweetness = safe_infer(audio_path, image_path, model, device)
+            if sweetness is None:
+                return "Error: The model was unable to make a prediction. Please try with different inputs."
+
+            print(f"\033[92mDEBUG\033[0m: Inference result: {sweetness.item()}")
+            return f"Predicted Sweetness: {sweetness.item():.2f}/10"
         except Exception as e:
-            import traceback
-            print(f"\033[91mERROR\033[0m: Prediction failed: {str(e)}")
+            print(f"\033[91mERROR\033[0m: Inference failed: {str(e)}")
             print(f"\033[91mTraceback\033[0m: {traceback.format_exc()}")
-            return f"Error processing input: {str(e)}"
+            return f"Error during inference: {str(e)}"
+
+    except Exception as e:
+        print(f"\033[91mERROR\033[0m: Prediction failed: {str(e)}")
+        print(f"\033[91mTraceback\033[0m: {traceback.format_exc()}")
+        return f"Error processing input: {str(e)}"
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Watermelon sweetness predictor")
+    parser.add_argument("--model_path", type=str, default="./models/model_15_20250405-033557.pt", help="Path to the trained model")
+    args = parser.parse_args()
+
+
+    # Create wrapper function for Gradio that passes the model
+    @spaces.GPU
+    def predict(audio, image):
+        model = load_model(args.model_path)
+        return predict_impl(audio, image, model)
+    print("\033[92mINFO\033[0m: GPU acceleration enabled via @spaces.GPU decorator")
 
+    # Set up Gradio interface
     audio_input = gr.Audio(label="Upload or Record Audio")
     image_input = gr.Image(label="Upload or Capture Image")
     output = gr.Textbox(label="Predicted Sweetness")
@@ -277,7 +297,7 @@ if __name__ == "__main__":
     )
 
     try:
-        interface.launch()  # Enable sharing to avoid localhost access issues
+        interface.launch()  # Launch the interface
     except Exception as e:
        print(f"\033[91mERROR\033[0m: Failed to launch interface: {e}")
        print("\033[93mTIP\033[0m: If you're running in a remote environment or container, try setting additional parameters:")
requirements.txt CHANGED
@@ -13,6 +13,7 @@ numpy==1.24.2
 Pillow==9.4.0
 tensorboard==2.13.0
 pydantic==2.10.6
+huggingface-hub>=0.15.1
 
 # Audio processing
 soundfile==0.12.1
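For context on the commit title and the new huggingface-hub>=0.15.1 pin: upload_folder is the huggingface_hub API that produces commits titled "Upload folder using huggingface_hub". A minimal sketch of such an upload (the repo id is a placeholder, and repo_type="space" is an assumption based on the app.py/requirements.txt layout):

from huggingface_hub import HfApi

api = HfApi()  # picks up the token from `huggingface-cli login` by default
api.upload_folder(
    folder_path=".",                   # local folder to push
    repo_id="Xalphinions/watermelon",  # placeholder repo id
    repo_type="space",                 # assumed target: a Gradio Space
    commit_message="Upload folder using huggingface_hub",
)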