Spaces:
Sleeping
Sleeping
import os | |
import sys | |
import torch | |
import numpy as np | |
import gradio as gr | |
import torchaudio | |
import torchvision | |
import spaces | |
# # Import Gradio Spaces GPU decorator | |
# try: | |
# from gradio import spaces | |
# HAS_SPACES = True | |
# print("\033[92mINFO\033[0m: Gradio Spaces detected, GPU acceleration will be enabled") | |
# except ImportError: | |
# HAS_SPACES = False | |
# print("\033[93mWARN\033[0m: gradio.spaces not available, running without GPU optimization") | |
# Append the parent of this file's directory to sys.path so the preprocess
# functions imported further down can be found. Because the path is appended
# (not prepended), a same-named module earlier on sys.path takes precedence.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Import the model class from train_watermelon
from train_watermelon import WatermelonModel | |
# App-specific variant of process_audio_data that tolerates several tensor layouts
def app_process_audio_data(waveform, sample_rate):
    """Turn a waveform tensor into the MFCC features the model consumes.

    This is the app's variant of ``process_audio_data``; it accepts 1-D, 2-D and
    3-D tensors and applies, in order:

      1. a 3-D tensor is reduced to its first item along dimension 0;
      2. a 2-D tensor is reduced to its first row (first channel);
      3. the signal is resampled to 16 kHz when ``sample_rate`` differs;
      4. it is zero-padded or truncated to exactly 3 s (48000 samples);
      5. 13 MFCC coefficients are computed (n_fft=256, win=256, hop=128, 40 mels).

    Args:
        waveform: audio samples as a torch tensor with 1, 2 or 3 dimensions.
        sample_rate: sampling rate of ``waveform`` in Hz.

    Returns:
        The tensor produced by ``torchaudio.transforms.MFCC``, or ``None`` if any
        step raised (the error and traceback are printed).
    """
    target_rate = 16000
    target_len = 3 * target_rate
    try:
        print(f"\033[92mDEBUG\033[0m: Processing audio - Initial shape: {waveform.shape}, Sample rate: {sample_rate}")

        # Batched input: keep only the first item.
        if waveform.dim() == 3:
            print("\033[92mDEBUG\033[0m: Found 3D tensor, converting to 2D")
            waveform = waveform[0]

        # Multi-channel input: keep only the first channel.
        if waveform.dim() == 2:
            waveform = waveform[0]
            print(f"\033[92mDEBUG\033[0m: Using first channel, new shape: {waveform.shape}")

        if sample_rate != target_rate:
            print(f"\033[92mDEBUG\033[0m: Resampling from {sample_rate}Hz to {target_rate}Hz")
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)
            waveform = resampler(waveform)

        # Fixed-length clip: pad short recordings with zeros, cut long ones.
        n_samples = waveform.size(0)
        if n_samples < target_len:
            print(f"\033[92mDEBUG\033[0m: Padding audio from {n_samples} to {target_len} samples")
            waveform = torch.nn.functional.pad(waveform, (0, target_len - n_samples))
        else:
            print(f"\033[92mDEBUG\033[0m: Trimming audio from {n_samples} to {target_len} samples")
            waveform = waveform[:target_len]

        print("\033[92mDEBUG\033[0m: Applying MFCC transformation")
        mel_settings = {"n_fft": 256, "win_length": 256, "hop_length": 128, "n_mels": 40}
        to_mfcc = torchaudio.transforms.MFCC(sample_rate=target_rate, n_mfcc=13, melkwargs=mel_settings)
        features = to_mfcc(waveform)
        print(f"\033[92mDEBUG\033[0m: MFCC output shape: {features.shape}")
        return features
    except Exception as exc:
        import traceback
        print(f"\033[91mERR!\033[0m: Error in audio processing: {exc}")
        print(traceback.format_exc())
        return None
# For images, reuse the original process_image_data from the preprocess module
from preprocess import process_image_data | |
# Brix bands shown in the prediction report: (lower bound, upper bound, label, assessment).
# The bands are sorted and contiguous, so a score belongs to the first band whose upper
# bound exceeds it; scores below 0 land in the first band, 13 or more in the last one.
_BRIX_SCALE = (
    (0, 8, "Low Sugar (< 8° Brix)",
     "Assessment: This watermelon has low sugar content. It may taste bland or slightly bitter."),
    (8, 9, "Mild Sweetness (8-9° Brix)",
     "Assessment: This watermelon has mild sweetness. Acceptable flavor but not very sweet."),
    (9, 10, "Medium Sweetness (9-10° Brix)",
     "Assessment: This watermelon has moderate sugar content. It should have pleasant sweetness."),
    (10, 11, "Sweet (10-11° Brix)",
     "Assessment: This watermelon has good sugar content! It should be sweet and juicy."),
    (11, 13, "Very Sweet (11-13° Brix)",
     "Assessment: This watermelon has excellent sugar content! Perfect choice for maximum sweetness and flavor."),
)

# Horizontal rule framing the scale in the report.
_SCALE_RULE = "━" * 34


def _select_device():
    """Return the CUDA device if PyTorch can see one, otherwise the CPU device.

    Called from inside the GPU-decorated prediction function, because on ZeroGPU
    Spaces a GPU is only attached while such a function runs.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"\033[92mINFO\033[0m: CUDA is available. Using device: {device}")
    else:
        device = torch.device("cpu")
        print(f"\033[92mINFO\033[0m: CUDA is not available. Using device: {device}")
    return device


def _load_model(model_path, device):
    """Build a WatermelonModel on *device*, load the state dict at *model_path*, set eval mode."""
    model = WatermelonModel().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    print(f"\033[92mINFO\033[0m: Loaded model from {model_path}")
    return model


def _unpack_audio(audio):
    """Split a Gradio audio value into ``(sample_rate, audio_data)``.

    Accepts a ``(sample_rate, data)`` tuple, a longer tuple whose first item is the
    rate and last item the data, or a path to an audio file (loaded with
    ``torchaudio.load``). Returns None for any other type.
    """
    if isinstance(audio, tuple) and len(audio) >= 2:
        sample_rate, audio_data = audio[0], audio[-1]
        print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
        print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
        return sample_rate, audio_data
    if isinstance(audio, str):
        audio_data, sample_rate = torchaudio.load(audio)
        print(f"\033[92mDEBUG\033[0m: Loaded audio from path with shape: {audio_data.shape}")
        return sample_rate, audio_data
    return None


def _audio_to_tensor(audio_data):
    """Return *audio_data* as a float tensor with channels on the first axis.

    Tensors (from ``torchaudio.load``) are only cast to float. NumPy arrays, as
    delivered by ``gr.Audio(type="numpy")``, are cast to float32. Integer PCM is
    rescaled into [-1, 1]. 1-D arrays gain a leading channel axis. 2-D arrays with
    more rows than columns are treated as ``(samples, channels)`` and transposed.
    """
    if not isinstance(audio_data, np.ndarray):
        return audio_data.float()

    print(f"\033[92mDEBUG\033[0m: Converting numpy audio with shape {audio_data.shape} to tensor")
    samples = audio_data
    if np.issubdtype(samples.dtype, np.integer):
        # torchaudio.load (the file-path branch) returns floats in [-1, 1]. Raw integer
        # counts (e.g. int16 up to 32767) would otherwise reach the MFCC at a very
        # different signal level, so scale them the same way.
        info = np.iinfo(samples.dtype)
        if info.min < 0:
            samples = samples.astype(np.float32) / -float(info.min)
        else:
            midpoint = (float(info.max) + 1.0) / 2.0
            samples = (samples.astype(np.float32) - midpoint) / midpoint
    tensor = torch.tensor(samples).float()
    if samples.ndim == 1:
        tensor = tensor.unsqueeze(0)
    elif samples.ndim == 2 and samples.shape[0] > samples.shape[1]:
        # More rows than columns: probably (samples, channels); put channels first.
        tensor = tensor.transpose(0, 1).contiguous()
    return tensor


def _image_to_tensor(image):
    """Convert a Gradio image value into a float ``(channels, height, width)`` tensor.

    NumPy arrays are converted in memory. They are assumed to be ``(H, W)`` or
    ``(H, W, C)`` pixel data as returned by ``gr.Image``; TODO confirm the dtype is
    uint8. Strings are treated as file paths and decoded with
    ``torchvision.io.read_image``. Returns None for any other input type.
    """
    if isinstance(image, np.ndarray):
        print(f"\033[92mDEBUG\033[0m: Converting numpy image with shape {image.shape} to tensor")
        pixels = image
        if pixels.ndim == 2:
            # Grayscale: add a channel axis, like read_image's (1, H, W) output.
            pixels = pixels[:, :, np.newaxis]
        elif pixels.ndim == 3 and pixels.shape[2] == 4:
            # Drop alpha (as PIL's RGBA -> RGB conversion does) rather than fail later.
            pixels = pixels[:, :, :3]
        # Convert directly instead of writing a shared temp JPEG and reading it back.
        # This avoids lossy re-encoding, disk I/O, and concurrent requests clobbering one file.
        tensor = torch.from_numpy(np.ascontiguousarray(pixels)).permute(2, 0, 1)
    elif isinstance(image, str):
        print(f"\033[92mDEBUG\033[0m: Loading and preprocessing image from {image}")
        tensor = torchvision.io.read_image(image)
    else:
        return None
    print(f"\033[92mDEBUG\033[0m: Loaded image shape: {tensor.shape}")
    return tensor.float()


def _format_brix_report(brix_score):
    """Render the user-facing report: the score, the Brix scale with the matching band marked, and an assessment."""
    band_index = len(_BRIX_SCALE) - 1
    for index, band in enumerate(_BRIX_SCALE):
        if brix_score < band[1]:
            band_index = index
            break

    lines = [
        f"🍉 Predicted Sugar Content: {brix_score:.1f}° Brix 🍉",
        "",
        "Sugar Content Scale (in °Brix):",
        _SCALE_RULE,
    ]
    for index, (low, high, label, _) in enumerate(_BRIX_SCALE):
        if index == band_index:
            lines.append(f"▶ {low}-{high}: {label} ◀ (YOUR WATERMELON)")
        else:
            # Two spaces keep unmarked rows aligned with the "▶ " marker.
            lines.append(f"  {low}-{high}: {label}")
    lines += [_SCALE_RULE, "", _BRIX_SCALE[band_index][3]]
    return "\n".join(lines)


# @spaces.GPU attaches a GPU for the duration of each call on ZeroGPU Spaces.
@spaces.GPU
def predict_sugar_content(audio, image, model_path):
    """Predict a watermelon's sugar content (°Brix) from a tap recording and a photo.

    Args:
        audio: Gradio audio value. Either a ``(sample_rate, samples)`` tuple (longer
            tuples use the first item as the rate and the last as the samples) or a
            path to an audio file.
        image: Gradio image value. Either a NumPy pixel array or a path to an image file.
        model_path: path of the ``WatermelonModel`` state dict to load.

    Returns:
        A human-readable report string. Failures are returned, not raised, as strings
        starting with ``"Error:"``; unexpected exceptions include their traceback.
    """
    try:
        # Gradio passes None when the user did not supply an input.
        if audio is None:
            return "Error: No audio provided. Please upload or record the sound of tapping the watermelon."
        if image is None:
            return "Error: No image provided. Please upload or capture an image of the watermelon."

        print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
        print(f"\033[92mDEBUG\033[0m: Audio input shape/length: {len(audio)}")
        print(f"\033[92mDEBUG\033[0m: Image input type: {type(image)}")
        if isinstance(image, np.ndarray):
            print(f"\033[92mDEBUG\033[0m: Image input shape: {image.shape}")

        unpacked = _unpack_audio(audio)
        if unpacked is None:
            return f"Error: Unsupported audio format. Got {type(audio)}"
        sample_rate, audio_data = unpacked

        audio_tensor = _audio_to_tensor(audio_data)
        print(f"\033[92mDEBUG\033[0m: Audio tensor shape before processing: {audio_tensor.shape}")
        mfcc = app_process_audio_data(audio_tensor, sample_rate)
        print(f"\033[92mDEBUG\033[0m: MFCC tensor shape after processing: {mfcc.shape if mfcc is not None else None}")

        image_tensor = _image_to_tensor(image)
        if image_tensor is None:
            return f"Error: Unsupported image format. Got {type(image)}"
        processed_image = process_image_data(image_tensor)
        print(f"\033[92mDEBUG\033[0m: Processed image shape: {processed_image.shape if processed_image is not None else None}")

        if mfcc is None or processed_image is None:
            return "Error: Failed to process inputs. Please check the debug logs."

        # Device selection and model loading happen inside the decorated call (see _select_device).
        device = _select_device()
        model = _load_model(model_path, device)

        # Add the batch dimension expected by the model and move inputs to its device.
        mfcc = mfcc.unsqueeze(0).to(device)
        print(f"\033[92mDEBUG\033[0m: Final MFCC shape with batch dimension: {mfcc.shape}")
        processed_image = processed_image.unsqueeze(0).to(device)
        print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}")

        print(f"\033[92mDEBUG\033[0m: Running inference on device: {device}")
        with torch.no_grad():
            brix_value = model(mfcc, processed_image)
        brix_score = brix_value.item()
        print(f"\033[92mDEBUG\033[0m: Prediction successful: {brix_score}")

        if not np.isfinite(brix_score):
            return "Error: Could not predict sugar content. Please try again with different inputs."
        return _format_brix_report(brix_score)
    except Exception as e:
        import traceback
        error_msg = f"Error: {str(e)}\n\n"
        error_msg += traceback.format_exc()
        print(f"\033[91mERR!\033[0m: {error_msg}")
        return error_msg


print("\033[92mINFO\033[0m: GPU-accelerated prediction function created with @spaces.GPU decorator")
def create_app(model_path):
    """Build (but do not launch) the Gradio Blocks interface for the predictor.

    Args:
        model_path: path to the trained weights; forwarded to
            ``predict_sugar_content`` each time the predict button is clicked.

    Returns:
        The ``gr.Blocks`` interface. The caller is responsible for calling
        ``launch()`` on it.
    """
    def predict_fn(audio, image):
        # The click event supplies only the two UI inputs; bind the model path here.
        return predict_sugar_content(audio, image, model_path)

    with gr.Blocks(title="Watermelon Sugar Content Predictor", theme=gr.themes.Soft()) as interface:
        gr.Markdown("# 🍉 Watermelon Sugar Content Predictor")
        gr.Markdown("""
This app predicts the sugar content (in °Brix) of a watermelon based on its sound and appearance.
## Instructions:
1. Upload or record an audio of tapping the watermelon
2. Upload or capture an image of the watermelon
3. Click 'Predict' to get the sugar content estimation
""")
        with gr.Row():
            with gr.Column():
                # type="numpy" delivers audio as a (sample_rate, samples) tuple.
                audio_input = gr.Audio(label="Upload or Record Audio", type="numpy")
                image_input = gr.Image(label="Upload or Capture Image")
                submit_btn = gr.Button("Predict Sugar Content", variant="primary")
            with gr.Column():
                output = gr.Textbox(label="Prediction Results", lines=12)
        submit_btn.click(
            fn=predict_fn,
            inputs=[audio_input, image_input],
            outputs=output,
        )
        gr.Markdown("""
## Tips for best results
- For audio: Tap the watermelon with your knuckle and record the sound
- For image: Take a clear photo of the whole watermelon in good lighting
## About Brix Measurement
Brix (°Bx) is a measurement of sugar content in a solution. For watermelons, higher Brix values indicate sweeter fruit.
The average ripe watermelon has a Brix value between 9-11°.
""")
    return interface
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Watermelon Sugar Content Prediction App")
    parser.add_argument(
        "--model_path",
        type=str,
        default="models/watermelon_model_final.pt",
        help="Path to the trained model file",
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="Create a shareable link for the app",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable verbose debug output",
    )
    args = parser.parse_args()

    if args.debug:
        # NOTE(review): in this script the flag only prints this notice; it does not
        # gate any of the DEBUG output printed elsewhere.
        print("\033[92mINFO\033[0m: Debug mode enabled")

    # The weights are only loaded when a prediction runs, so validate the path up
    # front. isfile() also rejects a directory, which exists() would accept.
    if not os.path.isfile(args.model_path):
        print(f"\033[91mERR!\033[0m: Model not found at {args.model_path}")
        print("\033[92mINFO\033[0m: Please train a model first or provide a valid model path")
        sys.exit(1)

    # Create and launch the app
    app = create_app(args.model_path)
    app.launch(share=args.share)