# watermelon2 / app_moe.py
# Author: Xalphinions
# Commit 2d3a3b3 (verified): Upload folder using huggingface_hub
import os
import sys
import torch
import numpy as np
import gradio as gr
import torchaudio
import torchvision
import json
# Add parent directory to path to import preprocess functions
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Import functions from preprocess and model definitions
from preprocess import process_image_data
from evaluate_backbones import WatermelonModelModular, IMAGE_BACKBONES, AUDIO_BACKBONES
# Expert configurations for the ensemble, taken from the backbone evaluation and
# listed best first. Each (image backbone, audio backbone) pair becomes one expert.
TOP_MODELS = [
    {"image_backbone": image_name, "audio_backbone": audio_name}
    for image_name, audio_name in (
        ("efficientnet_b3", "transformer"),
        ("efficientnet_b0", "transformer"),
        ("resnet50", "transformer"),
    )
]
# Define the MoE Model
class WatermelonMoEModel(torch.nn.Module):
    """Mixture-of-Experts ensemble that averages several WatermelonModelModular experts.

    Each entry of ``model_configs`` is loaded from
    ``<model_dir>/<image_backbone>_<audio_backbone>_model.pt``. Configurations whose
    checkpoint file is missing are skipped, and an error message is printed.
    """

    def __init__(self, model_configs, model_dir="models", weights=None):
        """
        Mixture of Experts model that combines multiple backbone models.

        Args:
            model_configs: List of dictionaries with 'image_backbone' and 'audio_backbone' keys
            model_dir: Directory where model checkpoints are stored
            weights: Optional sequence of weights (None or empty for equal weighting).
                Either one weight per entry of ``model_configs`` (weights belonging to
                configs whose checkpoint is missing are dropped), or one weight per
                successfully loaded model. The selected weights are normalized to sum
                to 1 so that ``forward`` returns a weighted average.

        Raises:
            ValueError: if ``weights`` matches neither the number of configs nor the
                number of loaded models, or if the selected weights do not sum to a
                positive value.
        """
        super().__init__()
        # nn.ModuleList (not a plain list) registers the experts as submodules, so
        # .to(), .eval(), .parameters() and state_dict() all reach them.
        self.models = torch.nn.ModuleList()
        self.model_configs = model_configs
        loaded_indices = []  # positions in model_configs whose checkpoint was loaded

        for idx, config in enumerate(model_configs):
            img_backbone = config["image_backbone"]
            audio_backbone = config["audio_backbone"]
            model_path = os.path.join(model_dir, f"{img_backbone}_{audio_backbone}_model.pt")
            if not os.path.exists(model_path):
                print(f"\033[91mERR!\033[0m: Model checkpoint not found at {model_path}")
                continue
            # Only build the network once we know its checkpoint exists.
            model = WatermelonModelModular(img_backbone, audio_backbone)
            print(f"\033[92mINFO\033[0m: Loading model {img_backbone}_{audio_backbone} from {model_path}")
            model.load_state_dict(torch.load(model_path, map_location='cpu'))
            model.eval()  # Set to evaluation mode
            self.models.append(model)
            loaded_indices.append(idx)

        self.weights = self._resolve_weights(weights, loaded_indices, len(model_configs))
        print(f"\033[92mINFO\033[0m: Loaded {len(self.models)} models for MoE ensemble")
        print(f"\033[92mINFO\033[0m: Model weights: {self.weights}")

    @staticmethod
    def _resolve_weights(weights, loaded_indices, num_configs):
        """Return the per-loaded-model weights, normalized to sum to 1.

        Args:
            weights: user-supplied weights or None/empty for uniform weighting.
            loaded_indices: indices into the config list of the experts that loaded.
            num_configs: total number of configs that were requested.

        Returns:
            A list with one float per loaded expert, or ``[1.0]`` when nothing loaded.
        """
        num_loaded = len(loaded_indices)
        if weights is None or len(weights) == 0:
            return [1.0 / num_loaded] * num_loaded if num_loaded else [1.0]

        weights = list(weights)
        if len(weights) == num_configs:
            # Weights are aligned with the configs: keep those of experts that loaded.
            selected = [float(weights[i]) for i in loaded_indices]
        elif len(weights) == num_loaded:
            selected = [float(w) for w in weights]
        else:
            raise ValueError(
                f"Number of weights ({len(weights)}) must match the number of model "
                f"configs ({num_configs}) or of loaded models ({num_loaded})"
            )

        if not selected:
            return [1.0]
        total = sum(selected)
        if total <= 0:
            raise ValueError("Model weights must sum to a positive value")
        return [w / total for w in selected]

    def to(self, *args, **kwargs):
        """Move/cast the ensemble exactly like ``nn.Module.to``.

        Kept for backward compatibility. The experts are registered in an
        ``nn.ModuleList``, so the base implementation already moves every one of them.
        """
        return super().to(*args, **kwargs)

    def forward(self, mfcc, image):
        """
        Forward pass through the MoE model.

        Returns the weighted sum of all expert outputs. This is a weighted average,
        because the weights built in ``__init__`` sum to 1. Gradients are disabled.
        If no expert was loaded, an error is printed and a placeholder
        ``tensor([0.0])`` on ``mfcc``'s device is returned.
        """
        if not self.models:
            print(f"\033[91mERR!\033[0m: No models available for inference!")
            return torch.tensor([0.0], device=mfcc.device)

        outputs = []
        with torch.no_grad():
            for i, (model, weight) in enumerate(zip(self.models, self.weights)):
                output = model(mfcc, image)
                print(f"\033[92mDEBUG\033[0m: Model {i} output: {output}")
                outputs.append(output * weight)
        # Weighted combination of the expert predictions.
        return torch.sum(torch.stack(outputs), dim=0)
# App-side variant of process_audio_data that tolerates the tensor layouts produced
# by both torchaudio.load and Gradio's numpy audio input.
def app_process_audio_data(waveform, sample_rate):
    """Convert a raw waveform tensor into the MFCC features used by the models.

    Accepted layouts:
      * 1-D ``(samples,)``;
      * 2-D ``(channels, samples)`` (torchaudio.load) or ``(samples, channels)``
        (Gradio numpy audio). The shorter axis is treated as the channel axis and
        only the first channel is kept;
      * 3-D: only the first item along dim 0 (batch) is used, then handled as 2-D.

    The mono signal is resampled to 16 kHz, padded or trimmed to exactly 3 seconds,
    and transformed with a 13-coefficient MFCC (n_fft=256, hop_length=128, 40 mels).

    Args:
        waveform: audio samples as a torch tensor.
        sample_rate: sampling rate of ``waveform`` in Hz.

    Returns:
        The MFCC tensor, or ``None`` if any step raised. The error and traceback
        are printed.
    """
    try:
        print(f"\033[92mDEBUG\033[0m: Processing audio - Initial shape: {waveform.shape}, Sample rate: {sample_rate}")
        # Handle different tensor dimensions
        if waveform.dim() == 3:
            print(f"\033[92mDEBUG\033[0m: Found 3D tensor, converting to 2D")
            # For 3D tensor, take the first item (batch dimension)
            waveform = waveform[0]
        if waveform.dim() == 2:
            # torchaudio.load yields (channels, samples); Gradio numpy audio yields
            # (samples, channels). Indexing [0] unconditionally would keep a single
            # *sample* per channel for the Gradio layout, so pick the short axis.
            if waveform.size(0) > waveform.size(1):
                waveform = waveform[:, 0]
            else:
                waveform = waveform[0]
            waveform = waveform.contiguous()  # column slices are strided views
            print(f"\033[92mDEBUG\033[0m: Using first channel, new shape: {waveform.shape}")

        # Resample to 16kHz if needed
        resample_rate = 16000
        if sample_rate != resample_rate:
            print(f"\033[92mDEBUG\033[0m: Resampling from {sample_rate}Hz to {resample_rate}Hz")
            waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=resample_rate)(waveform)

        # Ensure exactly 3 seconds of audio (zero-pad at the end or trim)
        target_len = 3 * resample_rate
        if waveform.size(0) < target_len:
            print(f"\033[92mDEBUG\033[0m: Padding audio from {waveform.size(0)} to {3 * resample_rate} samples")
            waveform = torch.nn.functional.pad(waveform, (0, target_len - waveform.size(0)))
        else:
            print(f"\033[92mDEBUG\033[0m: Trimming audio from {waveform.size(0)} to {3 * resample_rate} samples")
            waveform = waveform[:target_len]

        # Apply MFCC transformation
        print(f"\033[92mDEBUG\033[0m: Applying MFCC transformation")
        mfcc_transform = torchaudio.transforms.MFCC(
            sample_rate=resample_rate,
            n_mfcc=13,
            melkwargs={
                "n_fft": 256,
                "win_length": 256,
                "hop_length": 128,
                "n_mels": 40,
            },
        )
        mfcc = mfcc_transform(waveform)
        print(f"\033[92mDEBUG\033[0m: MFCC output shape: {mfcc.shape}")
        return mfcc
    except Exception as e:
        import traceback
        print(f"\033[91mERR!\033[0m: Error in audio processing: {e}")
        print(traceback.format_exc())
        return None
def _audio_to_float_tensor(audio_data):
    """Return ``audio_data`` (numpy array or torch tensor) as a float32 torch tensor.

    Signed-integer PCM arrays are scaled into [-1, 1) by dividing by 2**(bits-1).
    Such arrays are, e.g., the int16 samples Gradio's numpy audio typically delivers.
    The scaling matches the normalized float output of ``torchaudio.load`` used for
    file inputs. Float arrays and tensors are only cast to float32.
    """
    if isinstance(audio_data, np.ndarray):
        if np.issubdtype(audio_data.dtype, np.signedinteger):
            scale = float(-np.iinfo(audio_data.dtype).min)  # 32768 for int16
            return torch.from_numpy(audio_data.astype(np.float32) / scale)
        return torch.tensor(audio_data).float()
    return audio_data.float()


def _format_brix_report(prediction):
    """Render the user-facing report for an already-bounded Brix ``prediction``."""
    scale_ranges = [
        (0, 8, "Low Sugar (< 8° Brix)"),
        (8, 9, "Mild Sweetness (8-9° Brix)"),
        (9, 10, "Medium Sweetness (9-10° Brix)"),
        (10, 11, "Sweet (10-11° Brix)"),
        (11, 13, "Very Sweet (11-13° Brix)"),
    ]
    separator = "──────────────────────────────────\n"

    # Find the category the prediction falls into. Anything at or above the last
    # lower bound (including the 13.0 upper clamp) counts as the top category.
    user_category = None
    for min_val, max_val, category_name in scale_ranges:
        if min_val <= prediction < max_val:
            user_category = category_name
            break
    if prediction >= scale_ranges[-1][0]:
        user_category = scale_ranges[-1][2]

    parts = [
        f"🍉 Predicted Sugar Content: {prediction:.1f}° Brix 🍉\n\n",
        # Extra info about the MoE model
        "Using Ensemble of Top-3 Models:\n",
        "- EfficientNet-B3 + Transformer\n",
        "- EfficientNet-B0 + Transformer\n",
        "- ResNet-50 + Transformer\n\n",
        # Brix scale visualization with the user's result highlighted
        "Sugar Content Scale (in °Brix):\n",
        separator,
    ]
    for min_val, max_val, category_name in scale_ranges:
        if category_name == user_category:
            parts.append(f"▶ {min_val}-{max_val}: {category_name} ◀ (YOUR WATERMELON)\n")
        else:
            parts.append(f" {min_val}-{max_val}: {category_name}\n")
    parts.append(separator + "\n")

    if prediction < 8:
        assessment = "Assessment: This watermelon has low sugar content. It may taste bland or slightly bitter."
    elif prediction < 9:
        assessment = "Assessment: This watermelon has mild sweetness. Acceptable flavor but not very sweet."
    elif prediction < 10:
        assessment = "Assessment: This watermelon has moderate sugar content. It should have pleasant sweetness."
    elif prediction < 11:
        assessment = "Assessment: This watermelon has good sugar content! It should be sweet and juicy."
    else:
        assessment = "Assessment: This watermelon has excellent sugar content! Perfect choice for maximum sweetness and flavor."
    parts.append(assessment)
    return "".join(parts)


# Inference entry point used by the Gradio app. No GPU decorator is applied;
# CUDA is used whenever torch reports it available.
def predict_sugar_content(audio, image, model_dir="models", weights=None):
    """Predict watermelon sugar content (°Brix) from a tapping sound and a photo.

    Args:
        audio: ``(sample_rate, samples)`` tuple (Gradio ``type="numpy"``). For longer
            tuples the first item is the rate and the last is the samples. May also
            be a path to an audio file readable by ``torchaudio.load``.
        image: ``H x W x C`` or ``H x W`` (grayscale) numpy array, or a path to an
            image file readable by ``torchvision.io.read_image``.
        model_dir: directory holding the expert checkpoints.
        weights: optional ensemble weights forwarded to ``WatermelonMoEModel``.

    Returns:
        A formatted report string. On failure, a string starting with ``"Error:"``
        is returned instead. Exceptions are caught and reported, not raised.
    """
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"\033[92mINFO\033[0m: Using device: {device}")

        # Load MoE model and move it (with its experts) to the device
        moe_model = WatermelonMoEModel(TOP_MODELS, model_dir, weights)
        moe_model = moe_model.to(device)
        moe_model.eval()
        print(f"\033[92mINFO\033[0m: Loaded MoE model with {len(moe_model.models)} backbone models")
        if len(moe_model.models) == 0:
            # Without experts, forward() returns a placeholder 0.0. That value would be
            # clamped to 6.0 and reported as a genuine prediction.
            return f"Error: No model checkpoints could be loaded from '{model_dir}'"

        if audio is None:
            return "Error: Please provide an audio recording of the watermelon being tapped"
        if image is None:
            return "Error: Please provide an image of the watermelon"

        # Handle different audio input formats
        if isinstance(audio, tuple) and len(audio) >= 2:
            # Rate first, samples last (audio[-1] is audio[1] for a 2-tuple).
            sample_rate, audio_data = audio[0], audio[-1]
        elif isinstance(audio, str):
            audio_data, sample_rate = torchaudio.load(audio)
        else:
            return f"Error: Unsupported audio format. Got {type(audio)}"

        mfcc = app_process_audio_data(_audio_to_float_tensor(audio_data), sample_rate)
        if mfcc is None:
            return "Error: Failed to process audio input"

        # Process image into a 3-channel CxHxW tensor
        if isinstance(image, np.ndarray):
            if image.ndim == 2:  # grayscale HxW -> replicate into 3 channels
                image = np.stack([image] * 3, axis=-1)
            elif image.shape[-1] == 4:  # drop the alpha channel
                image = image[..., :3]
            image_tensor = torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1)
        elif isinstance(image, str):
            # Request RGB so file inputs also yield 3 channels (e.g. for RGBA PNGs).
            image_tensor = torchvision.io.read_image(image, mode=torchvision.io.ImageReadMode.RGB)
        else:
            return f"Error: Unsupported image format. Got {type(image)}"

        processed_image = process_image_data(image_tensor.float())
        if processed_image is None:
            return "Error: Failed to process image input"

        # Add batch dimension and move to device
        mfcc = mfcc.unsqueeze(0).to(device)
        processed_image = processed_image.unsqueeze(0).to(device)

        with torch.no_grad():
            brix_value = moe_model(mfcc, processed_image)
            prediction = brix_value.item()
        print(f"\033[92mDEBUG\033[0m: Raw prediction: {prediction}")

        # Keep the prediction within reasonable bounds (6-13 Brix)
        prediction = max(6.0, min(13.0, prediction))
        print(f"\033[92mDEBUG\033[0m: Bounded prediction: {prediction}")

        return _format_brix_report(prediction)
    except Exception as e:
        import traceback
        error_msg = f"Error: {str(e)}\n\n"
        error_msg += traceback.format_exc()
        print(f"\033[91mERR!\033[0m: {error_msg}")
        return error_msg
def create_app(model_dir="models", weights=None):
    """Build the Gradio Blocks interface for the MoE sugar-content predictor.

    The interface is returned unlaunched; the caller is responsible for ``.launch()``.

    Args:
        model_dir: checkpoint directory forwarded to ``predict_sugar_content``.
        weights: optional ensemble weights forwarded to ``predict_sugar_content``.

    Returns:
        The ``gr.Blocks`` interface.
    """
    # Bind the app configuration so the button callback only receives the two inputs.
    def predict_fn(audio, image):
        return predict_sugar_content(audio, image, model_dir, weights)

    with gr.Blocks(title="Watermelon Sugar Content Predictor (MoE)", theme=gr.themes.Soft()) as interface:
        gr.Markdown("# 🍉 Watermelon Sugar Content Predictor (Ensemble Model)")
        # Markdown kept at column 0: indented lines would render as code blocks, and
        # blank lines stop paragraphs from being absorbed into the preceding list.
        gr.Markdown("""
This app predicts the sugar content (in °Brix) of a watermelon based on its sound and appearance.

## What's New
This version uses a Mixture of Experts (MoE) ensemble model that combines the three best-performing models:
- EfficientNet-B3 + Transformer
- EfficientNet-B0 + Transformer
- ResNet-50 + Transformer

The ensemble approach provides more accurate predictions than any single model!

## Instructions:
1. Upload or record an audio of tapping the watermelon
2. Upload or capture an image of the watermelon
3. Click 'Predict' to get the sugar content estimation
""")

        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(label="Upload or Record Audio", type="numpy")
                image_input = gr.Image(label="Upload or Capture Image")
                submit_btn = gr.Button("Predict Sugar Content", variant="primary")
            with gr.Column():
                output = gr.Textbox(label="Prediction Results", lines=15)

        submit_btn.click(
            fn=predict_fn,
            inputs=[audio_input, image_input],
            outputs=output,
        )

        gr.Markdown("""
## Tips for best results
- For audio: Tap the watermelon with your knuckle and record the sound
- For image: Take a clear photo of the whole watermelon in good lighting

## About Brix Measurement
Brix (°Bx) is a measurement of sugar content in a solution. For watermelons, higher Brix values indicate sweeter fruit.
The average ripe watermelon has a Brix value between 9-11°.

## About the Mixture of Experts Model
This app uses a Mixture of Experts (MoE) model that combines predictions from multiple neural networks.
Our testing shows the ensemble approach achieves a Mean Absolute Error (MAE) of ~0.22, which is significantly
better than any individual model (best individual model: ~0.36 MAE).
""")

    return interface
if __name__ == "__main__":
    import argparse

    def _build_arg_parser():
        """Return the command-line parser for launching the MoE app."""
        cli = argparse.ArgumentParser(description="Watermelon Sugar Content Prediction App (MoE)")
        cli.add_argument("--model_dir", type=str, default="models",
                         help="Directory containing the model checkpoints")
        cli.add_argument("--share", action="store_true",
                         help="Create a shareable link for the app")
        cli.add_argument("--debug", action="store_true",
                         help="Enable verbose debug output")
        cli.add_argument("--weighting", type=str, choices=["uniform", "performance"],
                         default="uniform",
                         help="How to weight the models (uniform or based on performance)")
        return cli

    args = _build_arg_parser().parse_args()

    # --debug only prints this notice; it does not toggle any other output.
    if args.debug:
        print(f"\033[92mINFO\033[0m: Debug mode enabled")

    # Refuse to start without a checkpoint directory.
    if not os.path.exists(args.model_dir):
        print(f"\033[91mERR!\033[0m: Model directory not found at {args.model_dir}")
        sys.exit(1)

    weights = None
    if args.weighting != "performance":
        print(f"\033[92mINFO\033[0m: Using uniform weights")
    else:
        # MAE from the evaluation results, in TOP_MODELS order:
        # efficientnet_b3+transformer, efficientnet_b0+transformer, resnet50+transformer.
        # Each weight is proportional to 1/MAE, normalized to sum to 1.
        mae_values = [0.3635, 0.3765, 0.3959]
        inverse_mae = [1 / mae for mae in mae_values]
        total = sum(inverse_mae)
        weights = [val / total for val in inverse_mae]
        print(f"\033[92mINFO\033[0m: Using performance-based weights: {weights}")

    app = create_app(args.model_dir, weights)
    app.launch(share=args.share)