# watermelon2/app.py
import os
import sys
import torch
import numpy as np
import gradio as gr
import torchaudio
import torchvision
import spaces
import json
# Add parent directory to path to import preprocess functions
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Import functions from preprocess and model definitions
from preprocess import process_image_data
from evaluate_backbones import WatermelonModelModular, IMAGE_BACKBONES, AUDIO_BACKBONES
# Define the top-performing models based on evaluation
TOP_MODELS = [
{"image_backbone": "efficientnet_b3", "audio_backbone": "transformer"},
{"image_backbone": "efficientnet_b0", "audio_backbone": "transformer"},
{"image_backbone": "resnet50", "audio_backbone": "transformer"}
]
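# Each entry above is expected to map to a checkpoint file named
# "{image_backbone}_{audio_backbone}_model.pt" inside model_dir, e.g.
# "efficientnet_b3_transformer_model.pt" (see the loading loop in
# WatermelonMoEModel.__init__ below).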
# Define the MoE Model
class WatermelonMoEModel(torch.nn.Module):
    def __init__(self, model_configs, model_dir="models", weights=None):
        """
        Mixture of Experts model that combines multiple backbone models.

        Args:
            model_configs: List of dictionaries with 'image_backbone' and 'audio_backbone' keys
            model_dir: Directory where model checkpoints are stored
            weights: Optional list of weights for each model (None for equal weighting)
        """
        super(WatermelonMoEModel, self).__init__()
        self.models = torch.nn.ModuleList()  # Use ModuleList instead of a regular list
        self.model_configs = model_configs

        # Load each model
        loaded_count = 0
        for config in model_configs:
            img_backbone = config["image_backbone"]
            audio_backbone = config["audio_backbone"]
            # Initialize model
            model = WatermelonModelModular(img_backbone, audio_backbone)
            # Load weights
            model_path = os.path.join(model_dir, f"{img_backbone}_{audio_backbone}_model.pt")
            if os.path.exists(model_path):
                print(f"\033[92mINFO\033[0m: Loading model {img_backbone}_{audio_backbone} from {model_path}")
                try:
                    model.load_state_dict(torch.load(model_path, map_location='cpu'))
                    model.eval()  # Set to evaluation mode
                    self.models.append(model)
                    loaded_count += 1
                except Exception as e:
                    print(f"\033[91mERR!\033[0m: Failed to load model from {model_path}: {e}")
                    continue
            else:
                print(f"\033[91mERR!\033[0m: Model checkpoint not found at {model_path}")
                continue

        # Add a dummy parameter if no models were loaded to prevent StopIteration
        if loaded_count == 0:
            print(f"\033[91mERR!\033[0m: No models were successfully loaded!")
            self.dummy_param = torch.nn.Parameter(torch.zeros(1))

        # Set model weights (uniform by default)
        if weights and loaded_count > 0:
            assert len(weights) == len(self.models), "Number of weights must match number of models"
            self.weights = weights
        else:
            self.weights = [1.0 / max(loaded_count, 1)] * max(loaded_count, 1)

        print(f"\033[92mINFO\033[0m: Loaded {loaded_count} models for MoE ensemble")
        print(f"\033[92mINFO\033[0m: Model weights: {self.weights}")

    def to(self, device):
        """
        Override to() to ensure all sub-models are moved to the same device.
        """
        for model in self.models:
            model.to(device)
        return super(WatermelonMoEModel, self).to(device)

    def forward(self, mfcc, image):
        """
        Forward pass through the MoE model.
        Returns the weighted average of all model outputs.
        """
        # Check if we have models loaded
        if not self.models:
            print(f"\033[91mERR!\033[0m: No models available for inference!")
            return torch.tensor([0.0], device=mfcc.device)  # Return a default value

        outputs = []
        # Get outputs from each model
        with torch.no_grad():
            for i, model in enumerate(self.models):
                output = model(mfcc, image)
                outputs.append(output * self.weights[i])
        # Return weighted average
        return torch.sum(torch.stack(outputs), dim=0)
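
# Sketch of how the ensemble is combined (illustrative only; the tensor shapes
# below are assumptions, not values checked against the real checkpoints):
#   moe = WatermelonMoEModel(TOP_MODELS, model_dir="models")
#   mfcc = torch.randn(1, 13, 376)        # (batch, n_mfcc, time frames)
#   image = torch.randn(1, 3, 224, 224)   # (batch, channels, height, width)
#   brix = moe(mfcc, image)               # = sum_i weights[i] * expert_i(mfcc, image)
# With uniform weights w_i = 1/3 this is the plain mean of the three expert
# predictions; with performance-based weights it is a convex combination.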
# Modified version of process_audio_data specifically for the app to handle various tensor shapes
def app_process_audio_data(waveform, sample_rate):
"""Modified version of process_audio_data for the app that handles different tensor dimensions"""
try:
print(f"\033[92mDEBUG\033[0m: Processing audio - Initial shape: {waveform.shape}, Sample rate: {sample_rate}")
# Handle different tensor dimensions
if waveform.dim() == 3:
print(f"\033[92mDEBUG\033[0m: Found 3D tensor, converting to 2D")
# For 3D tensor, take the first item (batch dimension)
waveform = waveform[0]
if waveform.dim() == 2:
# Use the first channel for stereo audio
waveform = waveform[0]
print(f"\033[92mDEBUG\033[0m: Using first channel, new shape: {waveform.shape}")
# Resample to 16kHz if needed
resample_rate = 16000
if sample_rate != resample_rate:
print(f"\033[92mDEBUG\033[0m: Resampling from {sample_rate}Hz to {resample_rate}Hz")
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=resample_rate)(waveform)
# Ensure 3 seconds of audio
if waveform.size(0) < 3 * resample_rate:
print(f"\033[92mDEBUG\033[0m: Padding audio from {waveform.size(0)} to {3 * resample_rate} samples")
waveform = torch.nn.functional.pad(waveform, (0, 3 * resample_rate - waveform.size(0)))
else:
print(f"\033[92mDEBUG\033[0m: Trimming audio from {waveform.size(0)} to {3 * resample_rate} samples")
waveform = waveform[: 3 * resample_rate]
# Apply MFCC transformation
print(f"\033[92mDEBUG\033[0m: Applying MFCC transformation")
mfcc_transform = torchaudio.transforms.MFCC(
sample_rate=resample_rate,
n_mfcc=13,
melkwargs={
"n_fft": 256,
"win_length": 256,
"hop_length": 128,
"n_mels": 40,
}
)
mfcc = mfcc_transform(waveform)
print(f"\033[92mDEBUG\033[0m: MFCC output shape: {mfcc.shape}")
return mfcc
except Exception as e:
import traceback
print(f"\033[91mERR!\033[0m: Error in audio processing: {e}")
print(traceback.format_exc())
return None
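
# Rough shape math for the MFCC above (a sanity check, not verified against a
# specific torchaudio version): 3 s at 16 kHz is 48,000 samples; with
# hop_length=128 and centered frames that is about 48000/128 + 1 ≈ 376 frames,
# so each clip becomes a tensor of roughly (13, 376) before the batch
# dimension is added in predict_sugar_content().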
# Using the decorator for GPU acceleration
@spaces.GPU
def predict_sugar_content(audio, image, model_dir="models", weights=None):
"""Function with GPU acceleration to predict watermelon sugar content in Brix using MoE model"""
try:
# Check CUDA availability inside the GPU-decorated function
if torch.cuda.is_available():
device = torch.device("cuda")
print(f"\033[92mINFO\033[0m: CUDA is available. Using device: {device}")
else:
device = torch.device("cpu")
print(f"\033[92mINFO\033[0m: CUDA is not available. Using device: {device}")
# Load MoE model
moe_model = WatermelonMoEModel(TOP_MODELS, model_dir, weights)
# Explicitly move the entire model to device
moe_model = moe_model.to(device)
moe_model.eval()
print(f"\033[92mINFO\033[0m: Loaded MoE model with {len(moe_model.models)} backbone models")
        # Debug information about input types
        print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
        print(f"\033[92mDEBUG\033[0m: Audio input shape/length: {len(audio)}")
        print(f"\033[92mDEBUG\033[0m: Image input type: {type(image)}")
        if isinstance(image, np.ndarray):
            print(f"\033[92mDEBUG\033[0m: Image input shape: {image.shape}")

        # Handle different audio input formats
        if isinstance(audio, tuple) and len(audio) == 2:
            # Standard Gradio format: (sample_rate, audio_data)
            sample_rate, audio_data = audio
            print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
            print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
        elif isinstance(audio, tuple) and len(audio) > 2:
            # Sometimes Gradio returns (sample_rate, audio_data, other_info...)
            sample_rate, audio_data = audio[0], audio[-1]
            print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
            print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
        elif isinstance(audio, str):
            # Direct path to an audio file
            audio_data, sample_rate = torchaudio.load(audio)
            print(f"\033[92mDEBUG\033[0m: Loaded audio from path with shape: {audio_data.shape}")
        else:
            return f"Error: Unsupported audio format. Got {type(audio)}"

        # Create temporary file paths for the audio and image
        temp_dir = "temp"
        os.makedirs(temp_dir, exist_ok=True)
        temp_audio_path = os.path.join(temp_dir, "temp_audio.wav")
        temp_image_path = os.path.join(temp_dir, "temp_image.jpg")

        # Import necessary libraries
        from PIL import Image

        # Audio handling - direct processing from the data in memory
        if isinstance(audio_data, np.ndarray):
            # Convert numpy array to tensor
            print(f"\033[92mDEBUG\033[0m: Converting numpy audio with shape {audio_data.shape} to tensor")
            audio_tensor = torch.tensor(audio_data).float()
            # Handle different audio dimensions
            if audio_data.ndim == 1:
                # Single-channel audio
                audio_tensor = audio_tensor.unsqueeze(0)
            elif audio_data.ndim == 2:
                # Ensure channels are the first dimension
                if audio_data.shape[0] > audio_data.shape[1]:
                    # More rows than columns, probably (samples, channels)
                    audio_tensor = torch.tensor(audio_data.T).float()
        else:
            # Already a tensor
            audio_tensor = audio_data.float()

        print(f"\033[92mDEBUG\033[0m: Audio tensor shape before processing: {audio_tensor.shape}")
        # Skip saving/loading and process directly
        mfcc = app_process_audio_data(audio_tensor, sample_rate)
        print(f"\033[92mDEBUG\033[0m: MFCC tensor shape after processing: {mfcc.shape if mfcc is not None else None}")
        # Image handling
        if isinstance(image, np.ndarray):
            print(f"\033[92mDEBUG\033[0m: Converting numpy image with shape {image.shape} to PIL")
            pil_image = Image.fromarray(image)
            pil_image.save(temp_image_path)
            print(f"\033[92mDEBUG\033[0m: Saved image to {temp_image_path}")
        elif isinstance(image, str):
            # If image is already a path
            temp_image_path = image
            print(f"\033[92mDEBUG\033[0m: Using provided image path: {temp_image_path}")
        else:
            return f"Error: Unsupported image format. Got {type(image)}"

        # Process image
        print(f"\033[92mDEBUG\033[0m: Loading and preprocessing image from {temp_image_path}")
        image_tensor = torchvision.io.read_image(temp_image_path)
        print(f"\033[92mDEBUG\033[0m: Loaded image shape: {image_tensor.shape}")
        image_tensor = image_tensor.float()
        processed_image = process_image_data(image_tensor)
        print(f"\033[92mDEBUG\033[0m: Processed image shape: {processed_image.shape if processed_image is not None else None}")
        # Add batch dimension for inference and move to device
        if mfcc is not None:
            # Ensure mfcc is on the same device as the model
            mfcc = mfcc.unsqueeze(0).to(device)
            print(f"\033[92mDEBUG\033[0m: Final MFCC shape with batch dimension: {mfcc.shape}, device: {mfcc.device}")
        if processed_image is not None:
            # Ensure processed_image is on the same device as the model
            processed_image = processed_image.unsqueeze(0).to(device)
            print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}, device: {processed_image.device}")

        # Double-check the model is on the correct device
        try:
            param = next(moe_model.parameters())
            print(f"\033[92mDEBUG\033[0m: MoE model device: {param.device}")
            # Check individual models
            for i, model in enumerate(moe_model.models):
                try:
                    model_param = next(model.parameters())
                    print(f"\033[92mDEBUG\033[0m: Model {i} device: {model_param.device}")
                except StopIteration:
                    print(f"\033[91mERR!\033[0m: Model {i} has no parameters!")
        except StopIteration:
            print(f"\033[91mERR!\033[0m: MoE model has no parameters!")

        # Run inference with the MoE model
        print(f"\033[92mDEBUG\033[0m: Running inference with MoE model on device: {device}")
        if mfcc is not None and processed_image is not None:
            with torch.no_grad():
                brix_value = moe_model(mfcc, processed_image)
            print(f"\033[92mDEBUG\033[0m: Prediction successful: {brix_value.item()}")
        else:
            return "Error: Failed to process inputs. Please check the debug logs."
        # Format the result with a range display
        if brix_value is not None:
            brix_score = brix_value.item()
            # Create a header with the numerical result
            result = f"🍉 Predicted Sugar Content: {brix_score:.1f}° Brix 🍉\n\n"
            # Add extra info about the MoE model
            result += "Using Ensemble of Top-3 Models:\n"
            result += "- EfficientNet-B3 + Transformer\n"
            result += "- EfficientNet-B0 + Transformer\n"
            result += "- ResNet-50 + Transformer\n\n"
            # Add Brix scale visualization
            result += "Sugar Content Scale (in °Brix):\n"
            result += "──────────────────────────────────\n"
            # Create the scale display with Brix ranges
            scale_ranges = [
                (0, 8, "Low Sugar (< 8° Brix)"),
                (8, 9, "Mild Sweetness (8-9° Brix)"),
                (9, 10, "Medium Sweetness (9-10° Brix)"),
                (10, 11, "Sweet (10-11° Brix)"),
                (11, 13, "Very Sweet (11-13° Brix)")
            ]
            # Find which category the prediction falls into
            user_category = None
            for min_val, max_val, category_name in scale_ranges:
                if min_val <= brix_score < max_val:
                    user_category = category_name
                    break
            if user_category is None and brix_score >= scale_ranges[-1][1]:
                # Handle predictions above the top of the scale
                user_category = scale_ranges[-1][2]
            # Display the scale with the user's result highlighted
            for min_val, max_val, category_name in scale_ranges:
                if category_name == user_category:
                    result += f"▶ {min_val}-{max_val}: {category_name} ◀ (YOUR WATERMELON)\n"
                else:
                    result += f"  {min_val}-{max_val}: {category_name}\n"
            result += "──────────────────────────────────\n\n"
            # Add an assessment of the watermelon's sugar content
            if brix_score < 8:
                result += "Assessment: This watermelon has low sugar content. It may taste bland or slightly bitter."
            elif brix_score < 9:
                result += "Assessment: This watermelon has mild sweetness. Acceptable flavor but not very sweet."
            elif brix_score < 10:
                result += "Assessment: This watermelon has moderate sugar content. It should have pleasant sweetness."
            elif brix_score < 11:
                result += "Assessment: This watermelon has good sugar content! It should be sweet and juicy."
            else:
                result += "Assessment: This watermelon has excellent sugar content! Perfect choice for maximum sweetness and flavor."
            return result
        else:
            return "Error: Could not predict sugar content. Please try again with different inputs."
    except Exception as e:
        import traceback
        error_msg = f"Error: {str(e)}\n\n"
        error_msg += traceback.format_exc()
        print(f"\033[91mERR!\033[0m: {error_msg}")
        return error_msg
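
# Local smoke-test sketch (hypothetical file names; assumes the checkpoints the
# MoE expects are present under ./models):
#   print(predict_sugar_content("tap_recording.wav", "watermelon_photo.jpg"))
# Both arguments may be file paths, since the function also accepts the
# (sample_rate, numpy_array) tuple and the numpy image that Gradio passes in.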
def create_app(model_dir="models", weights=None):
"""Create and launch the Gradio interface"""
# Define the prediction function with model path
def predict_fn(audio, image):
return predict_sugar_content(audio, image, model_dir, weights)
# Create Gradio interface
with gr.Blocks(title="Watermelon Sugar Content Predictor (MoE)", theme=gr.themes.Soft()) as interface:
gr.Markdown("# 🍉 Watermelon Sugar Content Predictor (Ensemble Model)")
gr.Markdown("""
This app predicts the sugar content (in °Brix) of a watermelon based on its sound and appearance.
## What's New
This version uses a Mixture of Experts (MoE) ensemble model that combines the three best-performing models:
- EfficientNet-B3 + Transformer
- EfficientNet-B0 + Transformer
- ResNet-50 + Transformer
The ensemble approach provides more accurate predictions than any single model!
## Instructions:
1. Upload or record an audio of tapping the watermelon
2. Upload or capture an image of the watermelon
3. Click 'Predict' to get the sugar content estimation
""")
with gr.Row():
with gr.Column():
audio_input = gr.Audio(label="Upload or Record Audio", type="numpy")
image_input = gr.Image(label="Upload or Capture Image")
submit_btn = gr.Button("Predict Sugar Content", variant="primary")
with gr.Column():
output = gr.Textbox(label="Prediction Results", lines=15)
submit_btn.click(
fn=predict_fn,
inputs=[audio_input, image_input],
outputs=output
)
gr.Markdown("""
## Tips for best results
- For audio: Tap the watermelon with your knuckle and record the sound
- For image: Take a clear photo of the whole watermelon in good lighting
## About Brix Measurement
Brix (°Bx) is a measurement of sugar content in a solution. For watermelons, higher Brix values indicate sweeter fruit.
The average ripe watermelon has a Brix value between 9-11°.
## About the Mixture of Experts Model
This app uses a Mixture of Experts (MoE) model that combines predictions from multiple neural networks.
Our testing shows the ensemble approach achieves a Mean Absolute Error (MAE) of ~0.22, which is significantly
better than any individual model (best individual model: ~0.36 MAE).
""")
return interface
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Watermelon Sugar Content Prediction App (MoE)")
    parser.add_argument(
        "--model_dir",
        type=str,
        default="models",
        help="Directory containing the model checkpoints"
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="Create a shareable link for the app"
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable verbose debug output"
    )
    parser.add_argument(
        "--weighting",
        type=str,
        choices=["uniform", "performance"],
        default="uniform",
        help="How to weight the models (uniform or based on performance)"
    )
    args = parser.parse_args()

    if args.debug:
        print(f"\033[92mINFO\033[0m: Debug mode enabled")

    # Check if the model directory exists
    if not os.path.exists(args.model_dir):
        print(f"\033[91mERR!\033[0m: Model directory not found at {args.model_dir}")
        sys.exit(1)

    # Determine weights based on the --weighting argument
    weights = None
    if args.weighting == "performance":
        # Weights inversely proportional to the MAE (better models get higher weights)
        # These are the MAE values from the evaluation results
        mae_values = [0.3635, 0.3765, 0.3959]  # efficientnet_b3+transformer, efficientnet_b0+transformer, resnet50+transformer
        # Convert to weights (inverse of MAE, normalized)
        inverse_mae = [1/mae for mae in mae_values]
        total = sum(inverse_mae)
        weights = [val/total for val in inverse_mae]
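        # Worked example of this normalization (values rounded): with the MAEs
        # above, 1/MAE ≈ [2.751, 2.656, 2.526], which normalizes to weights of
        # roughly [0.347, 0.335, 0.318], so the EfficientNet-B3 expert gets the
        # largest (but only slightly larger) share of the vote.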
print(f"\033[92mINFO\033[0m: Using performance-based weights: {weights}")
else:
print(f"\033[92mINFO\033[0m: Using uniform weights")
# Create and launch the app
app = create_app(args.model_dir, weights)
app.launch(share=args.share)
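
# Example invocation (assumes the expected checkpoints are available under ./models):
#   python app.py --model_dir models --weighting performance --share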