import os import sys import torch import numpy as np import gradio as gr import torchaudio import torchvision import json # Add parent directory to path to import preprocess functions sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # Import functions from preprocess and model definitions from preprocess import process_image_data from evaluate_backbones import WatermelonModelModular, IMAGE_BACKBONES, AUDIO_BACKBONES # Define the top-performing models based on evaluation TOP_MODELS = [ {"image_backbone": "efficientnet_b3", "audio_backbone": "transformer"}, {"image_backbone": "efficientnet_b0", "audio_backbone": "transformer"}, {"image_backbone": "resnet50", "audio_backbone": "transformer"} ] # Define the MoE Model class WatermelonMoEModel(torch.nn.Module): def __init__(self, model_configs, model_dir="models", weights=None): """ Mixture of Experts model that combines multiple backbone models. Args: model_configs: List of dictionaries with 'image_backbone' and 'audio_backbone' keys model_dir: Directory where model checkpoints are stored weights: Optional list of weights for each model (None for equal weighting) """ super(WatermelonMoEModel, self).__init__() self.models = [] self.model_configs = model_configs # Load each model for config in model_configs: img_backbone = config["image_backbone"] audio_backbone = config["audio_backbone"] # Initialize model model = WatermelonModelModular(img_backbone, audio_backbone) # Load weights model_path = os.path.join(model_dir, f"{img_backbone}_{audio_backbone}_model.pt") if os.path.exists(model_path): print(f"\033[92mINFO\033[0m: Loading model {img_backbone}_{audio_backbone} from {model_path}") model.load_state_dict(torch.load(model_path, map_location='cpu')) else: print(f"\033[91mERR!\033[0m: Model checkpoint not found at {model_path}") continue model.eval() # Set to evaluation mode self.models.append(model) # Set model weights (uniform by default) if weights: assert len(weights) == len(self.models), "Number of weights must match number of models" self.weights = weights else: self.weights = [1.0 / len(self.models)] * len(self.models) if self.models else [1.0] print(f"\033[92mINFO\033[0m: Loaded {len(self.models)} models for MoE ensemble") print(f"\033[92mINFO\033[0m: Model weights: {self.weights}") def to(self, device): """ Override to() method to ensure all sub-models are moved to the same device """ for model in self.models: model.to(device) return super(WatermelonMoEModel, self).to(device) def forward(self, mfcc, image): """ Forward pass through the MoE model. Returns the weighted average of all model outputs. """ if not self.models: print(f"\033[91mERR!\033[0m: No models available for inference!") return torch.tensor([0.0], device=mfcc.device) outputs = [] # Get outputs from each model with torch.no_grad(): for i, model in enumerate(self.models): output = model(mfcc, image) # print the output value print(f"\033[92mDEBUG\033[0m: Model {i} output: {output}") outputs.append(output * self.weights[i]) # Return weighted average return torch.sum(torch.stack(outputs), dim=0) # Modified version of process_audio_data specifically for the app to handle various tensor shapes def app_process_audio_data(waveform, sample_rate): """Modified version of process_audio_data for the app that handles different tensor dimensions""" try: print(f"\033[92mDEBUG\033[0m: Processing audio - Initial shape: {waveform.shape}, Sample rate: {sample_rate}") # Handle different tensor dimensions if waveform.dim() == 3: print(f"\033[92mDEBUG\033[0m: Found 3D tensor, converting to 2D") # For 3D tensor, take the first item (batch dimension) waveform = waveform[0] if waveform.dim() == 2: # Use the first channel for stereo audio waveform = waveform[0] print(f"\033[92mDEBUG\033[0m: Using first channel, new shape: {waveform.shape}") # Resample to 16kHz if needed resample_rate = 16000 if sample_rate != resample_rate: print(f"\033[92mDEBUG\033[0m: Resampling from {sample_rate}Hz to {resample_rate}Hz") waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=resample_rate)(waveform) # Ensure 3 seconds of audio if waveform.size(0) < 3 * resample_rate: print(f"\033[92mDEBUG\033[0m: Padding audio from {waveform.size(0)} to {3 * resample_rate} samples") waveform = torch.nn.functional.pad(waveform, (0, 3 * resample_rate - waveform.size(0))) else: print(f"\033[92mDEBUG\033[0m: Trimming audio from {waveform.size(0)} to {3 * resample_rate} samples") waveform = waveform[: 3 * resample_rate] # Apply MFCC transformation print(f"\033[92mDEBUG\033[0m: Applying MFCC transformation") mfcc_transform = torchaudio.transforms.MFCC( sample_rate=resample_rate, n_mfcc=13, melkwargs={ "n_fft": 256, "win_length": 256, "hop_length": 128, "n_mels": 40, } ) mfcc = mfcc_transform(waveform) print(f"\033[92mDEBUG\033[0m: MFCC output shape: {mfcc.shape}") return mfcc except Exception as e: import traceback print(f"\033[91mERR!\033[0m: Error in audio processing: {e}") print(traceback.format_exc()) return None # Using the decorator for GPU acceleration def predict_sugar_content(audio, image, model_dir="models", weights=None): """Function with GPU acceleration to predict watermelon sugar content in Brix using MoE model""" try: # Check CUDA availability inside the GPU-decorated function device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"\033[92mINFO\033[0m: Using device: {device}") # Load MoE model moe_model = WatermelonMoEModel(TOP_MODELS, model_dir, weights) moe_model = moe_model.to(device) # Move entire model to device moe_model.eval() print(f"\033[92mINFO\033[0m: Loaded MoE model with {len(moe_model.models)} backbone models") # Handle different audio input formats if isinstance(audio, tuple) and len(audio) >= 2: sample_rate, audio_data = audio[0], audio[1] if len(audio) == 2 else audio[-1] elif isinstance(audio, str): audio_data, sample_rate = torchaudio.load(audio) else: return f"Error: Unsupported audio format. Got {type(audio)}" # Convert audio to tensor if needed if isinstance(audio_data, np.ndarray): audio_tensor = torch.tensor(audio_data).float() else: audio_tensor = audio_data.float() # Process audio mfcc = app_process_audio_data(audio_tensor, sample_rate) if mfcc is None: return "Error: Failed to process audio input" # Process image if isinstance(image, np.ndarray): image_tensor = torch.from_numpy(image).permute(2, 0, 1) # Convert to CxHxW format elif isinstance(image, str): image_tensor = torchvision.io.read_image(image) else: return f"Error: Unsupported image format. Got {type(image)}" image_tensor = image_tensor.float() processed_image = process_image_data(image_tensor) if processed_image is None: return "Error: Failed to process image input" # Add batch dimension and move to device mfcc = mfcc.unsqueeze(0).to(device) processed_image = processed_image.unsqueeze(0).to(device) # Run inference with torch.no_grad(): brix_value = moe_model(mfcc, processed_image) prediction = brix_value.item() print(f"\033[92mDEBUG\033[0m: Raw prediction: {prediction}") # Ensure prediction is within reasonable bounds (e.g., 6-13 Brix) prediction = max(6.0, min(13.0, prediction)) print(f"\033[92mDEBUG\033[0m: Bounded prediction: {prediction}") # Format the result result = f"πŸ‰ Predicted Sugar Content: {prediction:.1f}Β° Brix πŸ‰\n\n" # Add extra info about the MoE model result += "Using Ensemble of Top-3 Models:\n" result += "- EfficientNet-B3 + Transformer\n" result += "- EfficientNet-B0 + Transformer\n" result += "- ResNet-50 + Transformer\n\n" # Add Brix scale visualization result += "Sugar Content Scale (in Β°Brix):\n" result += "──────────────────────────────────\n" # Create the scale display with Brix ranges scale_ranges = [ (0, 8, "Low Sugar (< 8Β° Brix)"), (8, 9, "Mild Sweetness (8-9Β° Brix)"), (9, 10, "Medium Sweetness (9-10Β° Brix)"), (10, 11, "Sweet (10-11Β° Brix)"), (11, 13, "Very Sweet (11-13Β° Brix)") ] # Find which category the prediction falls into user_category = None for min_val, max_val, category_name in scale_ranges: if min_val <= prediction < max_val: user_category = category_name break if prediction >= scale_ranges[-1][0]: # Handle edge case user_category = scale_ranges[-1][2] # Display the scale with the user's result highlighted for min_val, max_val, category_name in scale_ranges: if category_name == user_category: result += f"β–Ά {min_val}-{max_val}: {category_name} β—€ (YOUR WATERMELON)\n" else: result += f" {min_val}-{max_val}: {category_name}\n" result += "──────────────────────────────────\n\n" # Add assessment of the watermelon's sugar content if prediction < 8: result += "Assessment: This watermelon has low sugar content. It may taste bland or slightly bitter." elif prediction < 9: result += "Assessment: This watermelon has mild sweetness. Acceptable flavor but not very sweet." elif prediction < 10: result += "Assessment: This watermelon has moderate sugar content. It should have pleasant sweetness." elif prediction < 11: result += "Assessment: This watermelon has good sugar content! It should be sweet and juicy." else: result += "Assessment: This watermelon has excellent sugar content! Perfect choice for maximum sweetness and flavor." return result except Exception as e: import traceback error_msg = f"Error: {str(e)}\n\n" error_msg += traceback.format_exc() print(f"\033[91mERR!\033[0m: {error_msg}") return error_msg def create_app(model_dir="models", weights=None): """Create and launch the Gradio interface""" # Define the prediction function with model path def predict_fn(audio, image): return predict_sugar_content(audio, image, model_dir, weights) # Create Gradio interface with gr.Blocks(title="Watermelon Sugar Content Predictor (MoE)", theme=gr.themes.Soft()) as interface: gr.Markdown("# πŸ‰ Watermelon Sugar Content Predictor (Ensemble Model)") gr.Markdown(""" This app predicts the sugar content (in Β°Brix) of a watermelon based on its sound and appearance. ## What's New This version uses a Mixture of Experts (MoE) ensemble model that combines the three best-performing models: - EfficientNet-B3 + Transformer - EfficientNet-B0 + Transformer - ResNet-50 + Transformer The ensemble approach provides more accurate predictions than any single model! ## Instructions: 1. Upload or record an audio of tapping the watermelon 2. Upload or capture an image of the watermelon 3. Click 'Predict' to get the sugar content estimation """) with gr.Row(): with gr.Column(): audio_input = gr.Audio(label="Upload or Record Audio", type="numpy") image_input = gr.Image(label="Upload or Capture Image") submit_btn = gr.Button("Predict Sugar Content", variant="primary") with gr.Column(): output = gr.Textbox(label="Prediction Results", lines=15) submit_btn.click( fn=predict_fn, inputs=[audio_input, image_input], outputs=output ) gr.Markdown(""" ## Tips for best results - For audio: Tap the watermelon with your knuckle and record the sound - For image: Take a clear photo of the whole watermelon in good lighting ## About Brix Measurement Brix (Β°Bx) is a measurement of sugar content in a solution. For watermelons, higher Brix values indicate sweeter fruit. The average ripe watermelon has a Brix value between 9-11Β°. ## About the Mixture of Experts Model This app uses a Mixture of Experts (MoE) model that combines predictions from multiple neural networks. Our testing shows the ensemble approach achieves a Mean Absolute Error (MAE) of ~0.22, which is significantly better than any individual model (best individual model: ~0.36 MAE). """) return interface if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="Watermelon Sugar Content Prediction App (MoE)") parser.add_argument( "--model_dir", type=str, default="models", help="Directory containing the model checkpoints" ) parser.add_argument( "--share", action="store_true", help="Create a shareable link for the app" ) parser.add_argument( "--debug", action="store_true", help="Enable verbose debug output" ) parser.add_argument( "--weighting", type=str, choices=["uniform", "performance"], default="uniform", help="How to weight the models (uniform or based on performance)" ) args = parser.parse_args() if args.debug: print(f"\033[92mINFO\033[0m: Debug mode enabled") # Check if model directory exists if not os.path.exists(args.model_dir): print(f"\033[91mERR!\033[0m: Model directory not found at {args.model_dir}") sys.exit(1) # Determine weights based on argument weights = None if args.weighting == "performance": # Weights inversely proportional to the MAE (better models get higher weights) # These are the MAE values from the evaluation results mae_values = [0.3635, 0.3765, 0.3959] # efficientnet_b3+transformer, efficientnet_b0+transformer, resnet50+transformer # Convert to weights (inverse of MAE, normalized) inverse_mae = [1/mae for mae in mae_values] total = sum(inverse_mae) weights = [val/total for val in inverse_mae] print(f"\033[92mINFO\033[0m: Using performance-based weights: {weights}") else: print(f"\033[92mINFO\033[0m: Using uniform weights") # Create and launch the app app = create_app(args.model_dir, weights) app.launch(share=args.share)