# watermelon2 / app.py
# Uploaded by Xalphinions using huggingface_hub (commit 5900417, verified; 13.7 kB)
import os
import sys
import torch
import numpy as np
import gradio as gr
import torchaudio
import torchvision
# Import the Hugging Face Spaces GPU decorator (ZeroGPU). It ships in the
# standalone `spaces` package; `gradio.spaces` is only tried as a fallback.
try:
    import spaces
    HAS_SPACES = True
except ImportError:
    try:
        from gradio import spaces
        HAS_SPACES = True
    except ImportError:
        HAS_SPACES = False
if HAS_SPACES:
    print("\033[92mINFO\033[0m: Gradio Spaces detected, GPU acceleration will be enabled")
else:
    print("\033[93mWARN\033[0m: spaces package not available, running without GPU optimization")
# Put the directory one level above this file on the import path so that the
# preprocessing/training modules imported below can be found.
_parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(_parent_dir)
del _parent_dir
# Import the model definition from train_watermelon.py
from train_watermelon import WatermelonModel
# Modified version of process_audio_data specifically for the app to handle various tensor shapes
def app_process_audio_data(waveform, sample_rate):
    """Turn a waveform tensor into an MFCC feature map for the model.

    App-specific variant of ``process_audio_data`` that accepts 1-D, 2-D or
    3-D tensors.

    Args:
        waveform: Audio tensor. For a 3-D input only index 0 along dim 0 is
            kept; for a 2-D input only the first row (first channel) is kept.
        sample_rate: Sampling rate of ``waveform`` in Hz.

    Returns:
        A 13-coefficient MFCC tensor computed at 16 kHz over exactly 3 seconds
        of audio (zero-padded or trimmed), or ``None`` if any step raised.
    """
    target_rate = 16000
    target_len = 3 * target_rate
    try:
        print(f"\033[92mDEBUG\033[0m: Processing audio - Initial shape: {waveform.shape}, Sample rate: {sample_rate}")

        # Peel off a leading batch dimension, then a channel dimension.
        if waveform.dim() == 3:
            print("\033[92mDEBUG\033[0m: Found 3D tensor, converting to 2D")
            waveform = waveform[0]
        if waveform.dim() == 2:
            waveform = waveform[0]
            print(f"\033[92mDEBUG\033[0m: Using first channel, new shape: {waveform.shape}")

        if sample_rate != target_rate:
            print(f"\033[92mDEBUG\033[0m: Resampling from {sample_rate}Hz to {target_rate}Hz")
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)
            waveform = resampler(waveform)

        # Force a fixed 3-second window.
        n_samples = waveform.size(0)
        if n_samples < target_len:
            print(f"\033[92mDEBUG\033[0m: Padding audio from {n_samples} to {target_len} samples")
            waveform = torch.nn.functional.pad(waveform, (0, target_len - n_samples))
        else:
            print(f"\033[92mDEBUG\033[0m: Trimming audio from {n_samples} to {target_len} samples")
            waveform = waveform[:target_len]

        print("\033[92mDEBUG\033[0m: Applying MFCC transformation")
        mel_settings = {"n_fft": 256, "win_length": 256, "hop_length": 128, "n_mels": 40}
        to_mfcc = torchaudio.transforms.MFCC(sample_rate=target_rate, n_mfcc=13, melkwargs=mel_settings)
        features = to_mfcc(waveform)
        print(f"\033[92mDEBUG\033[0m: MFCC output shape: {features.shape}")
        return features
    except Exception as err:
        import traceback
        print(f"\033[91mERR!\033[0m: Error in audio processing: {err}")
        print(traceback.format_exc())
        return None
# For images, reuse the original preprocessing function from preprocess.py unchanged
from preprocess import process_image_data
# Prediction: input parsing helpers followed by the Gradio-facing entry point.
def _parse_audio_input(audio):
    """Split a Gradio audio value into ``(audio_data, sample_rate)``.

    Accepts a ``(sample_rate, data, ...)`` tuple (the data is taken from the
    last position) or a path to an audio file, loaded with ``torchaudio.load``.
    Returns ``None`` for any other input type.
    """
    if isinstance(audio, tuple) and len(audio) >= 2:
        sample_rate, audio_data = audio[0], audio[-1]
    elif isinstance(audio, str):
        audio_data, sample_rate = torchaudio.load(audio)
    else:
        return None
    return audio_data, sample_rate


def _audio_to_tensor(audio_data):
    """Return *audio_data* as a float32 tensor with channels on dim 0 where detectable.

    Integer PCM arrays (which Gradio's numpy audio may deliver) are scaled to
    [-1, 1], which is also how ``torchaudio.load`` normalizes. This keeps the
    numpy path and the file-path path at the same amplitude scale.
    A 1-D array becomes ``(1, samples)``. A 2-D array with more rows than
    columns is treated as ``(samples, channels)`` and transposed. Any other
    shape is passed through.
    """
    if not isinstance(audio_data, np.ndarray):
        # Already a tensor (e.g. from torchaudio.load, which is pre-normalized).
        return audio_data.float()
    if np.issubdtype(audio_data.dtype, np.signedinteger):
        full_scale = -float(np.iinfo(audio_data.dtype).min)  # 32768 for int16
        audio_data = audio_data.astype(np.float32) / full_scale
    elif np.issubdtype(audio_data.dtype, np.unsignedinteger):
        midpoint = (float(np.iinfo(audio_data.dtype).max) + 1.0) / 2.0  # 128 for uint8
        audio_data = (audio_data.astype(np.float32) - midpoint) / midpoint
    tensor = torch.tensor(audio_data).float()
    if tensor.dim() == 1:
        tensor = tensor.unsqueeze(0)
    elif tensor.dim() == 2 and tensor.shape[0] > tensor.shape[1]:
        # More rows than columns: most likely (samples, channels).
        tensor = tensor.T.contiguous()
    return tensor


def _image_to_tensor(image):
    """Load *image* (numpy array or file path) as a float tensor via ``torchvision.io.read_image``.

    A numpy array is first written to a unique temporary JPEG, so concurrent
    requests cannot overwrite each other's file, and that file is deleted
    afterwards. Returns ``None`` for unsupported input types.
    """
    if isinstance(image, str):
        print(f"\033[92mDEBUG\033[0m: Using provided image path: {image}")
        return torchvision.io.read_image(image).float()
    if not isinstance(image, np.ndarray):
        return None

    import tempfile
    from PIL import Image

    pil_image = Image.fromarray(image)
    if pil_image.mode not in ("RGB", "L"):
        # JPEG cannot store alpha/palette modes such as RGBA or P.
        pil_image = pil_image.convert("RGB")
    fd, temp_image_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        pil_image.save(temp_image_path)
        print(f"\033[92mDEBUG\033[0m: Saved image to {temp_image_path}")
        return torchvision.io.read_image(temp_image_path).float()
    finally:
        os.remove(temp_image_path)


def _describe_sweetness(value):
    """Return the user-facing result text for a predicted sweetness *value*."""
    if value < 9:
        note = "This watermelon is not very sweet. You might want to choose another one."
    elif value < 10:
        note = "This watermelon has moderate sweetness."
    elif value < 11:
        note = "This watermelon is sweet! A good choice."
    else:
        note = "This watermelon is very sweet! Excellent choice!"
    return f"Predicted Sweetness: {value:.2f}/13\n\n{note}"


def predict_sweetness(audio, image, model_path):
    """Predict the sweetness of a watermelon from a tapping sound and a photo.

    Args:
        audio: Gradio audio value: a ``(sample_rate, data)`` tuple, where
            ``data`` is a numpy array or tensor, or a path to an audio file.
        image: A numpy image array or a path to an image file.
        model_path: Path to a ``WatermelonModel`` state dict.

    Returns:
        A human-readable result string. Failures are reported as a string that
        starts with ``"Error:"`` instead of being raised, so the Gradio textbox
        always shows something.
    """
    try:
        # Validate inputs before paying for model loading.
        if audio is None:
            return "Error: No audio provided. Please upload or record a tapping sound."
        if image is None:
            return "Error: No image provided. Please upload or capture a photo."

        print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
        print(f"\033[92mDEBUG\033[0m: Image input type: {type(image)}")
        if isinstance(image, np.ndarray):
            print(f"\033[92mDEBUG\033[0m: Image input shape: {image.shape}")

        parsed = _parse_audio_input(audio)
        if parsed is None:
            return f"Error: Unsupported audio format. Got {type(audio)}"
        audio_data, sample_rate = parsed
        print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
        print(f"\033[92mDEBUG\033[0m: Audio data shape: {getattr(audio_data, 'shape', None)}")

        if not isinstance(image, (np.ndarray, str)):
            return f"Error: Unsupported image format. Got {type(image)}"

        # Pick the device here rather than at import time, so the check runs
        # inside the (possibly GPU-decorated) call.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"\033[92mINFO\033[0m: Using device: {device}")

        model = WatermelonModel().to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        print(f"\033[92mINFO\033[0m: Loaded model from {model_path}")

        audio_tensor = _audio_to_tensor(audio_data)
        print(f"\033[92mDEBUG\033[0m: Audio tensor shape before processing: {audio_tensor.shape}")
        mfcc = app_process_audio_data(audio_tensor, sample_rate)

        image_tensor = _image_to_tensor(image)
        print(f"\033[92mDEBUG\033[0m: Loaded image shape: {image_tensor.shape}")
        processed_image = process_image_data(image_tensor)

        if mfcc is None or processed_image is None:
            return "Error: Failed to process inputs. Please check the debug logs."

        # Add a batch dimension and move both inputs to the model's device.
        mfcc = mfcc.unsqueeze(0).to(device)
        processed_image = processed_image.unsqueeze(0).to(device)
        print(f"\033[92mDEBUG\033[0m: Final shapes - MFCC: {mfcc.shape}, image: {processed_image.shape}")

        with torch.no_grad():
            sweetness = model(mfcc, processed_image).item()
        print(f"\033[92mDEBUG\033[0m: Prediction successful: {sweetness}")
        return _describe_sweetness(sweetness)
    except Exception as e:
        import traceback
        error_msg = f"Error: {str(e)}\n\n"
        error_msg += traceback.format_exc()
        print(f"\033[91mERR!\033[0m: {error_msg}")
        return error_msg
# Entry point used by the UI: the Spaces GPU-wrapped predictor when the Spaces
# decorator was importable, otherwise the plain predictor.
predict_sweetness_gpu = predict_sweetness
if HAS_SPACES:
    predict_sweetness_gpu = spaces.GPU(predict_sweetness)
    print("\033[92mINFO\033[0m: GPU optimization enabled for prediction function")
def create_app(model_path):
    """Build the Gradio Blocks UI for the sweetness predictor.

    Args:
        model_path: Path to the trained model weights, forwarded to every
            prediction request.

    Returns:
        The ``gr.Blocks`` interface, not yet launched. The caller must call
        ``.launch()``.
    """

    def predict_fn(audio, image):
        # predict_sweetness_gpu is either the Spaces GPU wrapper or
        # predict_sweetness itself, so one call covers both environments.
        return predict_sweetness_gpu(audio, image, model_path)

    with gr.Blocks(title="Watermelon Sweetness Predictor", theme=gr.themes.Soft()) as interface:
        gr.Markdown("# 🍉 Watermelon Sweetness Predictor")
        gr.Markdown("""
        This app predicts the sweetness of a watermelon based on its sound and appearance.

        ## Instructions:
        1. Upload or record an audio of tapping the watermelon
        2. Upload or capture an image of the watermelon
        3. Click 'Predict' to get the sweetness estimation
        """)

        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(label="Upload or Record Audio", type="numpy")
                image_input = gr.Image(label="Upload or Capture Image")
                submit_btn = gr.Button("Predict Sweetness", variant="primary")
            with gr.Column():
                output = gr.Textbox(label="Prediction Results", lines=6)

        submit_btn.click(
            fn=predict_fn,
            inputs=[audio_input, image_input],
            outputs=output,
        )

        # Blank line before "The model was trained..." keeps it from being
        # folded into the last bullet by Markdown's lazy list continuation.
        gr.Markdown("""
        ## How it works

        The app uses a deep learning model that combines:
        - Audio analysis using MFCC features and LSTM neural network
        - Image analysis using ResNet-50 convolutional neural network

        The model was trained on a dataset of watermelons with known sweetness values.

        ## Tips for best results
        - For audio: Tap the watermelon with your knuckle and record the sound
        - For image: Take a clear photo of the whole watermelon in good lighting
        """)

    return interface
if __name__ == "__main__":
    import argparse

    # Command-line options for running the app locally.
    cli = argparse.ArgumentParser(description="Watermelon Sweetness Prediction App")
    cli.add_argument("--model_path", type=str, default="models/watermelon_model_final.pt",
                     help="Path to the trained model file")
    cli.add_argument("--share", action="store_true",
                     help="Create a shareable link for the app")
    cli.add_argument("--debug", action="store_true",
                     help="Enable verbose debug output")
    opts = cli.parse_args()

    if opts.debug:
        print("\033[92mINFO\033[0m: Debug mode enabled")

    # Refuse to start without model weights.
    model_file = opts.model_path
    if not os.path.exists(model_file):
        print(f"\033[91mERR!\033[0m: Model not found at {model_file}")
        print("\033[92mINFO\033[0m: Please train a model first or provide a valid model path")
        sys.exit(1)

    create_app(model_file).launch(share=opts.share)