Xalphinions committed (verified)
Commit dd995d1 · 1 Parent(s): 13b45d3

Upload folder using huggingface_hub

Files changed (2):
  1. app.py +89 -180
  2. app_moe.py +13 -3
app.py CHANGED
@@ -34,11 +34,10 @@ class WatermelonMoEModel(torch.nn.Module):
             weights: Optional list of weights for each model (None for equal weighting)
         """
         super(WatermelonMoEModel, self).__init__()
-        self.models = torch.nn.ModuleList()  # Use ModuleList instead of regular list
+        self.models = []
         self.model_configs = model_configs

         # Load each model
-        loaded_count = 0
         for config in model_configs:
             img_backbone = config["image_backbone"]
             audio_backbone = config["audio_backbone"]
@@ -50,31 +49,22 @@
             model_path = os.path.join(model_dir, f"{img_backbone}_{audio_backbone}_model.pt")
             if os.path.exists(model_path):
                 print(f"\033[92mINFO\033[0m: Loading model {img_backbone}_{audio_backbone} from {model_path}")
-                try:
-                    model.load_state_dict(torch.load(model_path, map_location='cpu'))
-                    model.eval()  # Set to evaluation mode
-                    self.models.append(model)
-                    loaded_count += 1
-                except Exception as e:
-                    print(f"\033[91mERR!\033[0m: Failed to load model from {model_path}: {e}")
-                    continue
+                model.load_state_dict(torch.load(model_path, map_location='cpu'))
             else:
                 print(f"\033[91mERR!\033[0m: Model checkpoint not found at {model_path}")
                 continue
+
+            model.eval()  # Set to evaluation mode
+            self.models.append(model)

-        # Add a dummy parameter if no models were loaded to prevent StopIteration
-        if loaded_count == 0:
-            print(f"\033[91mERR!\033[0m: No models were successfully loaded!")
-            self.dummy_param = torch.nn.Parameter(torch.zeros(1))
-
         # Set model weights (uniform by default)
-        if weights and loaded_count > 0:
+        if weights:
             assert len(weights) == len(self.models), "Number of weights must match number of models"
             self.weights = weights
         else:
-            self.weights = [1.0 / max(loaded_count, 1)] * max(loaded_count, 1)
+            self.weights = [1.0 / len(self.models)] * len(self.models) if self.models else [1.0]

-        print(f"\033[92mINFO\033[0m: Loaded {loaded_count} models for MoE ensemble")
+        print(f"\033[92mINFO\033[0m: Loaded {len(self.models)} models for MoE ensemble")
         print(f"\033[92mINFO\033[0m: Model weights: {self.weights}")

     def to(self, device):
@@ -90,10 +80,9 @@ class WatermelonMoEModel(torch.nn.Module):
         Forward pass through the MoE model.
         Returns the weighted average of all model outputs.
         """
-        # Check if we have models loaded
         if not self.models:
             print(f"\033[91mERR!\033[0m: No models available for inference!")
-            return torch.tensor([0.0], device=mfcc.device)  # Return a default value
+            return torch.tensor([0.0], device=mfcc.device)

         outputs = []

@@ -101,6 +90,8 @@ class WatermelonMoEModel(torch.nn.Module):
         with torch.no_grad():
             for i, model in enumerate(self.models):
                 output = model(mfcc, image)
+                # print the output value
+                print(f"\033[92mDEBUG\033[0m: Model {i} output: {output}")
                 outputs.append(output * self.weights[i])

         # Return weighted average
@@ -166,196 +157,114 @@ def predict_sugar_content(audio, image, model_dir="models", weights=None):
     """Function with GPU acceleration to predict watermelon sugar content in Brix using MoE model"""
     try:
         # Check CUDA availability inside the GPU-decorated function
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-            print(f"\033[92mINFO\033[0m: CUDA is available. Using device: {device}")
-        else:
-            device = torch.device("cpu")
-            print(f"\033[92mINFO\033[0m: CUDA is not available. Using device: {device}")
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        print(f"\033[92mINFO\033[0m: Using device: {device}")

         # Load MoE model
         moe_model = WatermelonMoEModel(TOP_MODELS, model_dir, weights)
-        # Explicitly move the entire model to device
-        moe_model = moe_model.to(device)
+        moe_model = moe_model.to(device)  # Move entire model to device
         moe_model.eval()
         print(f"\033[92mINFO\033[0m: Loaded MoE model with {len(moe_model.models)} backbone models")

-        # Debug information about input types
-        print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
-        print(f"\033[92mDEBUG\033[0m: Audio input shape/length: {len(audio)}")
-        print(f"\033[92mDEBUG\033[0m: Image input type: {type(image)}")
-        if isinstance(image, np.ndarray):
-            print(f"\033[92mDEBUG\033[0m: Image input shape: {image.shape}")
-
         # Handle different audio input formats
-        if isinstance(audio, tuple) and len(audio) == 2:
-            # Standard Gradio format: (sample_rate, audio_data)
-            sample_rate, audio_data = audio
-            print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
-            print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
-        elif isinstance(audio, tuple) and len(audio) > 2:
-            # Sometimes Gradio returns (sample_rate, audio_data, other_info...)
-            sample_rate, audio_data = audio[0], audio[-1]
-            print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
-            print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
+        if isinstance(audio, tuple) and len(audio) >= 2:
+            sample_rate, audio_data = audio[0], audio[1] if len(audio) == 2 else audio[-1]
         elif isinstance(audio, str):
-            # Direct path to audio file
             audio_data, sample_rate = torchaudio.load(audio)
-            print(f"\033[92mDEBUG\033[0m: Loaded audio from path with shape: {audio_data.shape}")
         else:
             return f"Error: Unsupported audio format. Got {type(audio)}"

-        # Create a temporary file path for the audio and image
-        temp_dir = "temp"
-        os.makedirs(temp_dir, exist_ok=True)
-
-        temp_audio_path = os.path.join(temp_dir, "temp_audio.wav")
-        temp_image_path = os.path.join(temp_dir, "temp_image.jpg")
-
-        # Import necessary libraries
-        from PIL import Image
-
-        # Audio handling - direct processing from the data in memory
+        # Convert audio to tensor if needed
         if isinstance(audio_data, np.ndarray):
-            # Convert numpy array to tensor
-            print(f"\033[92mDEBUG\033[0m: Converting numpy audio with shape {audio_data.shape} to tensor")
             audio_tensor = torch.tensor(audio_data).float()
-
-            # Handle different audio dimensions
-            if audio_data.ndim == 1:
-                # Single channel audio
-                audio_tensor = audio_tensor.unsqueeze(0)
-            elif audio_data.ndim == 2:
-                # Ensure channels are first dimension
-                if audio_data.shape[0] > audio_data.shape[1]:
-                    # More rows than columns, probably (samples, channels)
-                    audio_tensor = torch.tensor(audio_data.T).float()
         else:
-            # Already a tensor
            audio_tensor = audio_data.float()

-        print(f"\033[92mDEBUG\033[0m: Audio tensor shape before processing: {audio_tensor.shape}")
-
-        # Skip saving/loading and process directly
+        # Process audio
         mfcc = app_process_audio_data(audio_tensor, sample_rate)
-        print(f"\033[92mDEBUG\033[0m: MFCC tensor shape after processing: {mfcc.shape if mfcc is not None else None}")
+        if mfcc is None:
+            return "Error: Failed to process audio input"

-        # Image handling
+        # Process image
         if isinstance(image, np.ndarray):
-            print(f"\033[92mDEBUG\033[0m: Converting numpy image with shape {image.shape} to PIL")
-            pil_image = Image.fromarray(image)
-            pil_image.save(temp_image_path)
-            print(f"\033[92mDEBUG\033[0m: Saved image to {temp_image_path}")
+            image_tensor = torch.from_numpy(image).permute(2, 0, 1)  # Convert to CxHxW format
         elif isinstance(image, str):
-            # If image is already a path
-            temp_image_path = image
-            print(f"\033[92mDEBUG\033[0m: Using provided image path: {temp_image_path}")
+            image_tensor = torchvision.io.read_image(image)
         else:
             return f"Error: Unsupported image format. Got {type(image)}"

-        # Process image
-        print(f"\033[92mDEBUG\033[0m: Loading and preprocessing image from {temp_image_path}")
-        image_tensor = torchvision.io.read_image(temp_image_path)
-        print(f"\033[92mDEBUG\033[0m: Loaded image shape: {image_tensor.shape}")
         image_tensor = image_tensor.float()
         processed_image = process_image_data(image_tensor)
-        print(f"\033[92mDEBUG\033[0m: Processed image shape: {processed_image.shape if processed_image is not None else None}")
-
-        # Add batch dimension for inference and move to device
-        if mfcc is not None:
-            # Ensure mfcc is on the same device as the model
-            mfcc = mfcc.unsqueeze(0).to(device)
-            print(f"\033[92mDEBUG\033[0m: Final MFCC shape with batch dimension: {mfcc.shape}, device: {mfcc.device}")
-
-        if processed_image is not None:
-            # Ensure processed_image is on the same device as the model
-            processed_image = processed_image.unsqueeze(0).to(device)
-            print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}, device: {processed_image.device}")
-
-        # Double-check model is on the correct device
-        try:
-            param = next(moe_model.parameters())
-            print(f"\033[92mDEBUG\033[0m: MoE model device: {param.device}")
-
-            # Check individual models
-            for i, model in enumerate(moe_model.models):
-                try:
-                    model_param = next(model.parameters())
-                    print(f"\033[92mDEBUG\033[0m: Model {i} device: {model_param.device}")
-                except StopIteration:
-                    print(f"\033[91mERR!\033[0m: Model {i} has no parameters!")
-        except StopIteration:
-            print(f"\033[91mERR!\033[0m: MoE model has no parameters!")
-
-        # Run inference with MoE model
-        print(f"\033[92mDEBUG\033[0m: Running inference with MoE model on device: {device}")
-        if mfcc is not None and processed_image is not None:
-            with torch.no_grad():
-                brix_value = moe_model(mfcc, processed_image)
-                print(f"\033[92mDEBUG\033[0m: Prediction successful: {brix_value.item()}")
-        else:
-            return "Error: Failed to process inputs. Please check the debug logs."
+        if processed_image is None:
+            return "Error: Failed to process image input"

-        # Format the result with a range display
-        if brix_value is not None:
-            brix_score = brix_value.item()
-
-            # Create a header with the numerical result
-            result = f"🍉 Predicted Sugar Content: {brix_score:.1f}° Brix 🍉\n\n"
-
-            # Add extra info about the MoE model
-            result += "Using Ensemble of Top-3 Models:\n"
-            result += "- EfficientNet-B3 + Transformer\n"
-            result += "- EfficientNet-B0 + Transformer\n"
-            result += "- ResNet-50 + Transformer\n\n"
-
-            # Add Brix scale visualization
-            result += "Sugar Content Scale (in °Brix):\n"
-            result += "──────────────────────────────────\n"
-
-            # Create the scale display with Brix ranges
-            scale_ranges = [
-                (0, 8, "Low Sugar (< 8° Brix)"),
-                (8, 9, "Mild Sweetness (8-9° Brix)"),
-                (9, 10, "Medium Sweetness (9-10° Brix)"),
-                (10, 11, "Sweet (10-11° Brix)"),
-                (11, 13, "Very Sweet (11-13° Brix)")
-            ]
-
-            # Find which category the prediction falls into
-            user_category = None
-            for min_val, max_val, category_name in scale_ranges:
-                if min_val <= brix_score < max_val:
-                    user_category = category_name
-                    break
-            if brix_score >= scale_ranges[-1][0]:  # Handle edge case
-                user_category = scale_ranges[-1][2]
-
-            # Display the scale with the user's result highlighted
-            for min_val, max_val, category_name in scale_ranges:
-                if category_name == user_category:
-                    result += f"▶ {min_val}-{max_val}: {category_name} ◀ (YOUR WATERMELON)\n"
-                else:
-                    result += f"  {min_val}-{max_val}: {category_name}\n"
-
-            result += "──────────────────────────────────\n\n"
+        # Add batch dimension and move to device
+        mfcc = mfcc.unsqueeze(0).to(device)
+        processed_image = processed_image.unsqueeze(0).to(device)
+
+        # Run inference
+        with torch.no_grad():
+            brix_value = moe_model(mfcc, processed_image)
+        prediction = brix_value.item()
+        print(f"\033[92mDEBUG\033[0m: Raw prediction: {prediction}")

-            # Add assessment of the watermelon's sugar content
-            if brix_score < 8:
-                result += "Assessment: This watermelon has low sugar content. It may taste bland or slightly bitter."
-            elif brix_score < 9:
-                result += "Assessment: This watermelon has mild sweetness. Acceptable flavor but not very sweet."
-            elif brix_score < 10:
-                result += "Assessment: This watermelon has moderate sugar content. It should have pleasant sweetness."
-            elif brix_score < 11:
-                result += "Assessment: This watermelon has good sugar content! It should be sweet and juicy."
+        # Ensure prediction is within reasonable bounds (e.g., 6-13 Brix)
+        prediction = max(6.0, min(13.0, prediction))
+        print(f"\033[92mDEBUG\033[0m: Bounded prediction: {prediction}")
+
+        # Format the result
+        result = f"🍉 Predicted Sugar Content: {prediction:.1f}° Brix 🍉\n\n"
+
+        # Add extra info about the MoE model
+        result += "Using Ensemble of Top-3 Models:\n"
+        result += "- EfficientNet-B3 + Transformer\n"
+        result += "- EfficientNet-B0 + Transformer\n"
+        result += "- ResNet-50 + Transformer\n\n"
+
+        # Add Brix scale visualization
+        result += "Sugar Content Scale (in °Brix):\n"
+        result += "──────────────────────────────────\n"
+
+        # Create the scale display with Brix ranges
+        scale_ranges = [
+            (0, 8, "Low Sugar (< 8° Brix)"),
+            (8, 9, "Mild Sweetness (8-9° Brix)"),
+            (9, 10, "Medium Sweetness (9-10° Brix)"),
+            (10, 11, "Sweet (10-11° Brix)"),
+            (11, 13, "Very Sweet (11-13° Brix)")
+        ]
+
+        # Find which category the prediction falls into
+        user_category = None
+        for min_val, max_val, category_name in scale_ranges:
+            if min_val <= prediction < max_val:
+                user_category = category_name
+                break
+        if prediction >= scale_ranges[-1][0]:  # Handle edge case
+            user_category = scale_ranges[-1][2]
+
+        # Display the scale with the user's result highlighted
+        for min_val, max_val, category_name in scale_ranges:
+            if category_name == user_category:
+                result += f"▶ {min_val}-{max_val}: {category_name} ◀ (YOUR WATERMELON)\n"
             else:
-                result += "Assessment: This watermelon has excellent sugar content! Perfect choice for maximum sweetness and flavor."
-
-            return result
+                result += f"  {min_val}-{max_val}: {category_name}\n"
+
+        result += "──────────────────────────────────\n\n"
+
+        # Add assessment of the watermelon's sugar content
+        if prediction < 8:
+            result += "Assessment: This watermelon has low sugar content. It may taste bland or slightly bitter."
+        elif prediction < 9:
+            result += "Assessment: This watermelon has mild sweetness. Acceptable flavor but not very sweet."
+        elif prediction < 10:
+            result += "Assessment: This watermelon has moderate sugar content. It should have pleasant sweetness."
+        elif prediction < 11:
+            result += "Assessment: This watermelon has good sugar content! It should be sweet and juicy."
         else:
-            return "Error: Could not predict sugar content. Please try again with different inputs."
-
+            result += "Assessment: This watermelon has excellent sugar content! Perfect choice for maximum sweetness and flavor."
+
+        return result
     except Exception as e:
         import traceback
         error_msg = f"Error: {str(e)}\n\n"
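Note: both the old and new forward() combine the per-backbone outputs as a simple weighted average, with uniform weights by default. A minimal, runnable sketch of that combination step is below; ToyModel and weighted_ensemble are hypothetical stand-ins for illustration only, not code from this repository.

import torch

def weighted_ensemble(models, weights, mfcc, image):
    # Weighted average of per-model predictions, in the spirit of WatermelonMoEModel.forward
    outputs = []
    with torch.no_grad():
        for model, w in zip(models, weights):
            outputs.append(model(mfcc, image) * w)
    # Summing the pre-weighted outputs yields the weighted average when the weights sum to 1
    return torch.stack(outputs).sum(dim=0)

class ToyModel(torch.nn.Module):
    # Stand-in for a backbone pair; returns a fixed "Brix" value and ignores its inputs
    def __init__(self, bias):
        super().__init__()
        self.bias = bias
    def forward(self, mfcc, image):
        return torch.tensor([10.0 + self.bias])

models = [ToyModel(0.2), ToyModel(-0.2)]
weights = [0.5, 0.5]  # uniform weighting, matching the default in the diff
print(weighted_ensemble(models, weights, mfcc=None, image=None))  # tensor([10.])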
 
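Note: the rewritten predict_sugar_content collapses the earlier input-handling branches into the logic sketched below. normalize_audio_input and bound_brix are hypothetical helper names used purely for illustration; the committed code performs these steps inline.

import numpy as np
import torch

def normalize_audio_input(audio):
    # Gradio typically passes (sample_rate, np.ndarray); extra tuple elements are
    # ignored apart from the last, mirroring the branch in the updated function.
    if isinstance(audio, tuple) and len(audio) >= 2:
        sample_rate = audio[0]
        audio_data = audio[1] if len(audio) == 2 else audio[-1]
    else:
        raise ValueError(f"Unsupported audio format: {type(audio)}")
    if isinstance(audio_data, np.ndarray):
        audio_tensor = torch.tensor(audio_data).float()
    else:
        audio_tensor = audio_data.float()
    return sample_rate, audio_tensor

def bound_brix(prediction, low=6.0, high=13.0):
    # Clamp the raw regression output to the 6-13 Brix range used in the diff
    return max(low, min(high, prediction))

sr, wav = normalize_audio_input((16000, np.zeros(16000, dtype=np.float32)))
print(sr, wav.shape)     # 16000 torch.Size([16000])
print(bound_brix(14.2))  # 13.0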
app_moe.py CHANGED
@@ -273,9 +273,19 @@ def predict_sugar_content(audio, image, model_dir="models", weights=None):
             print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}, device: {processed_image.device}")

         # Double-check model is on the correct device
-        print(f"\033[92mDEBUG\033[0m: MoE model device: {next(moe_model.parameters()).device}")
-        for i, model in enumerate(moe_model.models):
-            print(f"\033[92mDEBUG\033[0m: Model {i} device: {next(model.parameters()).device}")
+        try:
+            param = next(moe_model.parameters())
+            print(f"\033[92mDEBUG\033[0m: MoE model device: {param.device}")
+
+            # Check individual models
+            for i, model in enumerate(moe_model.models):
+                try:
+                    model_param = next(model.parameters())
+                    print(f"\033[92mDEBUG\033[0m: Model {i} device: {model_param.device}")
+                except StopIteration:
+                    print(f"\033[91mERR!\033[0m: Model {i} has no parameters!")
+        except StopIteration:
+            print(f"\033[91mERR!\033[0m: MoE model has no parameters!")

         # Run inference with MoE model
         print(f"\033[92mDEBUG\033[0m: Running inference with MoE model on device: {device}")