import os
import sys
import torch
import numpy as np
import gradio as gr
import torchaudio
import torchvision

# Add parent directory to path to import preprocess functions
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Import functions from infer_watermelon.py
from infer_watermelon import load_model


# Modified version of process_audio_data specifically for the app to handle various tensor shapes
def app_process_audio_data(waveform, sample_rate):
    """Modified version of process_audio_data for the app that handles different tensor dimensions"""
    try:
        print(f"\033[92mDEBUG\033[0m: Processing audio - Initial shape: {waveform.shape}, Sample rate: {sample_rate}")

        # Handle different tensor dimensions
        if waveform.dim() == 3:
            print(f"\033[92mDEBUG\033[0m: Found 3D tensor, converting to 2D")
            # For 3D tensor, take the first item (batch dimension)
            waveform = waveform[0]

        if waveform.dim() == 2:
            # Use the first channel for stereo audio
            waveform = waveform[0]
            print(f"\033[92mDEBUG\033[0m: Using first channel, new shape: {waveform.shape}")

        # Resample to 16kHz if needed
        resample_rate = 16000
        if sample_rate != resample_rate:
            print(f"\033[92mDEBUG\033[0m: Resampling from {sample_rate}Hz to {resample_rate}Hz")
            waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=resample_rate)(waveform)

        # Ensure 3 seconds of audio
        if waveform.size(0) < 3 * resample_rate:
            print(f"\033[92mDEBUG\033[0m: Padding audio from {waveform.size(0)} to {3 * resample_rate} samples")
            waveform = torch.nn.functional.pad(waveform, (0, 3 * resample_rate - waveform.size(0)))
        else:
            print(f"\033[92mDEBUG\033[0m: Trimming audio from {waveform.size(0)} to {3 * resample_rate} samples")
            waveform = waveform[: 3 * resample_rate]

        # Apply MFCC transformation
        print(f"\033[92mDEBUG\033[0m: Applying MFCC transformation")
        mfcc_transform = torchaudio.transforms.MFCC(
            sample_rate=resample_rate,
            n_mfcc=13,
            melkwargs={
                "n_fft": 256,
                "win_length": 256,
                "hop_length": 128,
                "n_mels": 40,
            }
        )

        mfcc = mfcc_transform(waveform)
        print(f"\033[92mDEBUG\033[0m: MFCC output shape: {mfcc.shape}")

        return mfcc
    except Exception as e:
        import traceback
        print(f"\033[91mERR!\033[0m: Error in audio processing: {e}")
        print(traceback.format_exc())
        return None


# Similarly for images, but let's import the original one
from preprocess import process_image_data


def init_model(model_path):
    """Initialize the model for inference"""
    model, device = load_model(model_path)
    return model, device


def predict_sweetness(audio, image, model, device):
    """Predict sweetness of a watermelon from audio and image input"""
    try:
        # Debug information about input types
        print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
        print(f"\033[92mDEBUG\033[0m: Audio input shape/length: {len(audio)}")
        print(f"\033[92mDEBUG\033[0m: Image input type: {type(image)}")
        if isinstance(image, np.ndarray):
            print(f"\033[92mDEBUG\033[0m: Image input shape: {image.shape}")

        # Handle different audio input formats
        if isinstance(audio, tuple) and len(audio) == 2:
            # Standard Gradio format: (sample_rate, audio_data)
            sample_rate, audio_data = audio
            print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
            print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
        elif isinstance(audio, tuple) and len(audio) > 2:
            # Sometimes Gradio returns (sample_rate, audio_data, other_info...)
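            # Take the sample rate from the first element and the waveform from the last.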
            sample_rate, audio_data = audio[0], audio[-1]
            print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
            print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
        elif isinstance(audio, str):
            # Direct path to audio file
            import torchaudio
            audio_data, sample_rate = torchaudio.load(audio)
            print(f"\033[92mDEBUG\033[0m: Loaded audio from path with shape: {audio_data.shape}")
        else:
            return f"Error: Unsupported audio format. Got {type(audio)}"

        # Create a temporary file path for the audio and image
        temp_dir = "temp"
        os.makedirs(temp_dir, exist_ok=True)

        temp_audio_path = os.path.join(temp_dir, "temp_audio.wav")
        temp_image_path = os.path.join(temp_dir, "temp_image.jpg")

        # Import necessary libraries
        import torchaudio
        import torchvision
        import torchvision.transforms.functional as F
        from PIL import Image

        # Audio handling - direct processing from the data in memory
        if isinstance(audio_data, np.ndarray):
            # Convert numpy array to tensor
            print(f"\033[92mDEBUG\033[0m: Converting numpy audio with shape {audio_data.shape} to tensor")
            audio_tensor = torch.tensor(audio_data).float()

            # Handle different audio dimensions
            if audio_data.ndim == 1:
                # Single channel audio
                audio_tensor = audio_tensor.unsqueeze(0)
            elif audio_data.ndim == 2:
                # Ensure channels are first dimension
                if audio_data.shape[0] > audio_data.shape[1]:
                    # More rows than columns, probably (samples, channels)
                    audio_tensor = torch.tensor(audio_data.T).float()
        else:
            # Already a tensor
            audio_tensor = audio_data.float()

        print(f"\033[92mDEBUG\033[0m: Audio tensor shape before processing: {audio_tensor.shape}")

        # Skip saving/loading and process directly
        mfcc = app_process_audio_data(audio_tensor, sample_rate)
        print(f"\033[92mDEBUG\033[0m: MFCC tensor shape after processing: {mfcc.shape if mfcc is not None else None}")

        # Image handling
        if isinstance(image, np.ndarray):
            print(f"\033[92mDEBUG\033[0m: Converting numpy image with shape {image.shape} to PIL")
            pil_image = Image.fromarray(image)
            pil_image.save(temp_image_path)
            print(f"\033[92mDEBUG\033[0m: Saved image to {temp_image_path}")
        elif isinstance(image, str):
            # If image is already a path
            temp_image_path = image
            print(f"\033[92mDEBUG\033[0m: Using provided image path: {temp_image_path}")
        else:
            return f"Error: Unsupported image format. Got {type(image)}"

        # Process image
        print(f"\033[92mDEBUG\033[0m: Loading and preprocessing image from {temp_image_path}")
        image_tensor = torchvision.io.read_image(temp_image_path)
        print(f"\033[92mDEBUG\033[0m: Loaded image shape: {image_tensor.shape}")
        image_tensor = image_tensor.float()
        processed_image = process_image_data(image_tensor)
        print(f"\033[92mDEBUG\033[0m: Processed image shape: {processed_image.shape if processed_image is not None else None}")

        # Add batch dimension for inference
        if mfcc is not None:
            mfcc = mfcc.unsqueeze(0).to(device)
            print(f"\033[92mDEBUG\033[0m: Final MFCC shape with batch dimension: {mfcc.shape}")

        if processed_image is not None:
            processed_image = processed_image.unsqueeze(0).to(device)
            print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}")

        # Run inference
        print(f"\033[92mDEBUG\033[0m: Running inference")
        if mfcc is not None and processed_image is not None:
            with torch.no_grad():
                sweetness = model(mfcc, processed_image)
                print(f"\033[92mDEBUG\033[0m: Prediction successful: {sweetness.item()}")
        else:
            return "Error: Failed to process inputs. Please check the debug logs."
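        # Both inputs were processed successfully, so `sweetness` now holds the model's scalar prediction.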
        # Format the result
        if sweetness is not None:
            result = f"Predicted Sweetness: {sweetness.item():.2f}/13"

            # Add a qualitative description
            if sweetness.item() < 9:
                result += "\n\nThis watermelon is not very sweet. You might want to choose another one."
            elif sweetness.item() < 10:
                result += "\n\nThis watermelon has moderate sweetness."
            elif sweetness.item() < 11:
                result += "\n\nThis watermelon is sweet! A good choice."
            else:
                result += "\n\nThis watermelon is very sweet! Excellent choice!"

            return result
        else:
            return "Error: Could not predict sweetness. Please try again with different inputs."

    except Exception as e:
        import traceback
        error_msg = f"Error: {str(e)}\n\n"
        error_msg += traceback.format_exc()
        print(f"\033[91mERR!\033[0m: {error_msg}")
        return error_msg


def create_app(model_path):
    """Create and launch the Gradio interface"""
    # Initialize model
    model, device = init_model(model_path)

    # Define the prediction function with model and device
    def predict_fn(audio, image):
        return predict_sweetness(audio, image, model, device)

    # Create Gradio interface
    with gr.Blocks(title="Watermelon Sweetness Predictor") as interface:
        gr.Markdown("# 🍉 Watermelon Sweetness Predictor")
        gr.Markdown("""
        This app predicts the sweetness of a watermelon based on its sound and appearance.

        ## Instructions:
        1. Upload or record audio of tapping the watermelon
        2. Upload or capture an image of the watermelon
        3. Click 'Predict Sweetness' to get the predicted sweetness
        """)

        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(label="Upload or Record Audio", type="numpy")
                image_input = gr.Image(label="Upload or Capture Image")
                submit_btn = gr.Button("Predict Sweetness", variant="primary")

            with gr.Column():
                output = gr.Textbox(label="Prediction Results", lines=6)

        submit_btn.click(
            fn=predict_fn,
            inputs=[audio_input, image_input],
            outputs=output
        )

        gr.Markdown("""
        ## How it works

        The app uses a deep learning model that combines:
        - Audio analysis using MFCC features and an LSTM neural network
        - Image analysis using a ResNet-50 convolutional neural network

        The model was trained on a dataset of watermelons with known sweetness values.
        """)

    return interface


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Watermelon Sweetness Prediction App")
    parser.add_argument(
        "--model_path",
        type=str,
        default="models/watermelon_model_final.pt",
        help="Path to the trained model file"
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="Create a shareable link for the app"
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable verbose debug output"
    )

    args = parser.parse_args()

    if args.debug:
        print(f"\033[92mINFO\033[0m: Debug mode enabled")

    # Check if model exists
    if not os.path.exists(args.model_path):
        print(f"\033[91mERR!\033[0m: Model not found at {args.model_path}")
        print("\033[92mINFO\033[0m: Please train a model first or provide a valid model path")
        sys.exit(1)

    # Create and launch the app
    app = create_app(args.model_path)
    app.launch(share=args.share)
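# --- Usage sketch ---
# Assumptions: this script is saved as app.py and a trained model exists at the
# default path (adjust the filename and paths to your setup):
#
#   python app.py
#   python app.py --model_path models/watermelon_model_final.pt --share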