import os
import sys
import torch
import numpy as np
import gradio as gr
import torchaudio
import torchvision
import spaces
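# `spaces` provides the @spaces.GPU decorator, which requests a GPU for the
# duration of a call on Hugging Face ZeroGPU Spaces (and is designed to be a
# harmless pass-through when the app runs outside that environment).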
# Add parent directory to path to import preprocess functions
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Import the model definition from train_watermelon.py
from train_watermelon import WatermelonModel
# App-specific variant of preprocess.process_audio_data that tolerates the
# varying tensor shapes Gradio can deliver (1-D mono, 2-D multi-channel, or 3-D batched)
def app_process_audio_data(waveform, sample_rate):
    """Convert a waveform to MFCC features, accepting 1-D, 2-D, or 3-D input.

    Returns a (n_mfcc, n_frames) tensor on success, or None on failure.
    """
try:
print(f"\033[92mDEBUG\033[0m: Processing audio - Initial shape: {waveform.shape}, Sample rate: {sample_rate}")
# Handle different tensor dimensions
if waveform.dim() == 3:
print(f"\033[92mDEBUG\033[0m: Found 3D tensor, converting to 2D")
# For 3D tensor, take the first item (batch dimension)
waveform = waveform[0]
if waveform.dim() == 2:
# Use the first channel for stereo audio
waveform = waveform[0]
print(f"\033[92mDEBUG\033[0m: Using first channel, new shape: {waveform.shape}")
# Resample to 16kHz if needed
resample_rate = 16000
if sample_rate != resample_rate:
print(f"\033[92mDEBUG\033[0m: Resampling from {sample_rate}Hz to {resample_rate}Hz")
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=resample_rate)(waveform)
# Ensure 3 seconds of audio
if waveform.size(0) < 3 * resample_rate:
print(f"\033[92mDEBUG\033[0m: Padding audio from {waveform.size(0)} to {3 * resample_rate} samples")
waveform = torch.nn.functional.pad(waveform, (0, 3 * resample_rate - waveform.size(0)))
else:
print(f"\033[92mDEBUG\033[0m: Trimming audio from {waveform.size(0)} to {3 * resample_rate} samples")
waveform = waveform[: 3 * resample_rate]
# Apply MFCC transformation
print(f"\033[92mDEBUG\033[0m: Applying MFCC transformation")
mfcc_transform = torchaudio.transforms.MFCC(
sample_rate=resample_rate,
n_mfcc=13,
melkwargs={
"n_fft": 256,
"win_length": 256,
"hop_length": 128,
"n_mels": 40,
}
)
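        # For a 3 s clip at 16 kHz with hop_length=128 this yields roughly a
        # (13, 376) tensor (13 coefficients x ~376 frames), assuming
        # torchaudio's default center padding.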
mfcc = mfcc_transform(waveform)
print(f"\033[92mDEBUG\033[0m: MFCC output shape: {mfcc.shape}")
return mfcc
except Exception as e:
import traceback
print(f"\033[91mERR!\033[0m: Error in audio processing: {e}")
print(traceback.format_exc())
return None
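# Hypothetical sanity check (illustrative, not executed at import time):
#   dummy = torch.zeros(1, 16000)               # 1 s of silence, mono
#   mfcc = app_process_audio_data(dummy, 16000)
#   # padded to 3 s -> expected shape (13, 376)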
# For images, the original preprocessing function works as-is, so import it directly
from preprocess import process_image_data
# Using the decorator directly on the function definition
@spaces.GPU
def predict_sugar_content(audio, image, model_path):
"""Function with GPU acceleration to predict watermelon sugar content in Brix"""
try:
# Now check CUDA availability inside the GPU-decorated function
if torch.cuda.is_available():
device = torch.device("cuda")
print(f"\033[92mINFO\033[0m: CUDA is available. Using device: {device}")
else:
device = torch.device("cpu")
print(f"\033[92mINFO\033[0m: CUDA is not available. Using device: {device}")
# Load model inside the function to ensure it's on the correct device
model = WatermelonModel().to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
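        # eval() disables dropout and freezes batch-norm statistics, so
        # repeated predictions on identical inputs are deterministic.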
print(f"\033[92mINFO\033[0m: Loaded model from {model_path}")
# Debug information about input types
print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
print(f"\033[92mDEBUG\033[0m: Audio input shape/length: {len(audio)}")
print(f"\033[92mDEBUG\033[0m: Image input type: {type(image)}")
if isinstance(image, np.ndarray):
print(f"\033[92mDEBUG\033[0m: Image input shape: {image.shape}")
# Handle different audio input formats
if isinstance(audio, tuple) and len(audio) == 2:
# Standard Gradio format: (sample_rate, audio_data)
sample_rate, audio_data = audio
print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
elif isinstance(audio, tuple) and len(audio) > 2:
# Sometimes Gradio returns (sample_rate, audio_data, other_info...)
sample_rate, audio_data = audio[0], audio[-1]
print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
elif isinstance(audio, str):
# Direct path to audio file
audio_data, sample_rate = torchaudio.load(audio)
print(f"\033[92mDEBUG\033[0m: Loaded audio from path with shape: {audio_data.shape}")
else:
return f"Error: Unsupported audio format. Got {type(audio)}"
        # Create a temporary file path for the image (audio is processed
        # entirely in memory, so no audio file is written)
        temp_dir = "temp"
        os.makedirs(temp_dir, exist_ok=True)
        temp_image_path = os.path.join(temp_dir, "temp_image.jpg")
# Import necessary libraries
from PIL import Image
# Audio handling - direct processing from the data in memory
if isinstance(audio_data, np.ndarray):
# Convert numpy array to tensor
print(f"\033[92mDEBUG\033[0m: Converting numpy audio with shape {audio_data.shape} to tensor")
audio_tensor = torch.tensor(audio_data).float()
# Handle different audio dimensions
if audio_data.ndim == 1:
# Single channel audio
audio_tensor = audio_tensor.unsqueeze(0)
elif audio_data.ndim == 2:
# Ensure channels are first dimension
if audio_data.shape[0] > audio_data.shape[1]:
# More rows than columns, probably (samples, channels)
audio_tensor = torch.tensor(audio_data.T).float()
else:
# Already a tensor
audio_tensor = audio_data.float()
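        # audio_tensor should now be (channels, samples); app_process_audio_data
        # reduces it to mono and a fixed 3 s window before the MFCC transform.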
print(f"\033[92mDEBUG\033[0m: Audio tensor shape before processing: {audio_tensor.shape}")
# Skip saving/loading and process directly
mfcc = app_process_audio_data(audio_tensor, sample_rate)
print(f"\033[92mDEBUG\033[0m: MFCC tensor shape after processing: {mfcc.shape if mfcc is not None else None}")
# Image handling
if isinstance(image, np.ndarray):
print(f"\033[92mDEBUG\033[0m: Converting numpy image with shape {image.shape} to PIL")
pil_image = Image.fromarray(image)
pil_image.save(temp_image_path)
print(f"\033[92mDEBUG\033[0m: Saved image to {temp_image_path}")
elif isinstance(image, str):
# If image is already a path
temp_image_path = image
print(f"\033[92mDEBUG\033[0m: Using provided image path: {temp_image_path}")
else:
return f"Error: Unsupported image format. Got {type(image)}"
# Process image
print(f"\033[92mDEBUG\033[0m: Loading and preprocessing image from {temp_image_path}")
image_tensor = torchvision.io.read_image(temp_image_path)
print(f"\033[92mDEBUG\033[0m: Loaded image shape: {image_tensor.shape}")
image_tensor = image_tensor.float()
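        # process_image_data (imported from preprocess.py) is expected to
        # resize/normalize the raw image tensor into the model's input format.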
processed_image = process_image_data(image_tensor)
print(f"\033[92mDEBUG\033[0m: Processed image shape: {processed_image.shape if processed_image is not None else None}")
# Add batch dimension for inference and move to device
if mfcc is not None:
mfcc = mfcc.unsqueeze(0).to(device)
print(f"\033[92mDEBUG\033[0m: Final MFCC shape with batch dimension: {mfcc.shape}")
if processed_image is not None:
processed_image = processed_image.unsqueeze(0).to(device)
print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}")
# Run inference
print(f"\033[92mDEBUG\033[0m: Running inference on device: {device}")
if mfcc is not None and processed_image is not None:
with torch.no_grad():
brix_value = model(mfcc, processed_image)
print(f"\033[92mDEBUG\033[0m: Prediction successful: {brix_value.item()}")
else:
return "Error: Failed to process inputs. Please check the debug logs."
# Format the result with a range display
if brix_value is not None:
brix_score = brix_value.item()
# Create a header with the numerical result
result = f"πŸ‰ Predicted Sugar Content: {brix_score:.1f}Β° Brix πŸ‰\n\n"
# Add Brix scale visualization
result += "Sugar Content Scale (in Β°Brix):\n"
result += "──────────────────────────────────\n"
# Create the scale display with Brix ranges
scale_ranges = [
(0, 8, "Low Sugar (< 8Β° Brix)"),
(8, 9, "Mild Sweetness (8-9Β° Brix)"),
(9, 10, "Medium Sweetness (9-10Β° Brix)"),
(10, 11, "Sweet (10-11Β° Brix)"),
(11, 13, "Very Sweet (11-13Β° Brix)")
]
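            # Ranges are half-open [min, max): a prediction of exactly 10.0
            # falls into the 10-11 "Sweet" bucket.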
# Find which category the prediction falls into
user_category = None
for min_val, max_val, category_name in scale_ranges:
if min_val <= brix_score < max_val:
user_category = category_name
break
if brix_score >= scale_ranges[-1][0]: # Handle edge case
user_category = scale_ranges[-1][2]
# Display the scale with the user's result highlighted
for min_val, max_val, category_name in scale_ranges:
if category_name == user_category:
result += f"β–Ά {min_val}-{max_val}: {category_name} β—€ (YOUR WATERMELON)\n"
else:
result += f" {min_val}-{max_val}: {category_name}\n"
result += "──────────────────────────────────\n\n"
# Add assessment of the watermelon's sugar content
if brix_score < 8:
result += "Assessment: This watermelon has low sugar content. It may taste bland or slightly bitter."
elif brix_score < 9:
result += "Assessment: This watermelon has mild sweetness. Acceptable flavor but not very sweet."
elif brix_score < 10:
result += "Assessment: This watermelon has moderate sugar content. It should have pleasant sweetness."
elif brix_score < 11:
result += "Assessment: This watermelon has good sugar content! It should be sweet and juicy."
else:
result += "Assessment: This watermelon has excellent sugar content! Perfect choice for maximum sweetness and flavor."
return result
else:
return "Error: Could not predict sugar content. Please try again with different inputs."
except Exception as e:
import traceback
error_msg = f"Error: {str(e)}\n\n"
error_msg += traceback.format_exc()
print(f"\033[91mERR!\033[0m: {error_msg}")
return error_msg
print("\033[92mINFO\033[0m: GPU-accelerated prediction function created with @spaces.GPU decorator")
def create_app(model_path):
"""Create and launch the Gradio interface"""
# Define the prediction function with model path
def predict_fn(audio, image):
return predict_sugar_content(audio, image, model_path)
# Create Gradio interface
with gr.Blocks(title="Watermelon Sugar Content Predictor", theme=gr.themes.Soft()) as interface:
gr.Markdown("# πŸ‰ Watermelon Sugar Content Predictor")
gr.Markdown("""
This app predicts the sugar content (in Β°Brix) of a watermelon based on its sound and appearance.
## Instructions:
    1. Upload or record an audio clip of tapping the watermelon
2. Upload or capture an image of the watermelon
3. Click 'Predict' to get the sugar content estimation
""")
with gr.Row():
with gr.Column():
audio_input = gr.Audio(label="Upload or Record Audio", type="numpy")
image_input = gr.Image(label="Upload or Capture Image")
submit_btn = gr.Button("Predict Sugar Content", variant="primary")
with gr.Column():
output = gr.Textbox(label="Prediction Results", lines=12)
submit_btn.click(
fn=predict_fn,
inputs=[audio_input, image_input],
outputs=output
)
gr.Markdown("""
## Tips for best results
- For audio: Tap the watermelon with your knuckle and record the sound
- For image: Take a clear photo of the whole watermelon in good lighting
## About Brix Measurement
Brix (Β°Bx) is a measurement of sugar content in a solution. For watermelons, higher Brix values indicate sweeter fruit.
    A typical ripe watermelon measures between 9° and 11° Brix.
""")
return interface
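# Hypothetical programmatic use (the model path below is illustrative):
#   demo = create_app("models/watermelon_model_final.pt")
#   demo.launch(server_name="0.0.0.0", share=False)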
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Watermelon Sugar Content Prediction App")
parser.add_argument(
"--model_path",
type=str,
default="models/watermelon_model_final.pt",
help="Path to the trained model file"
)
parser.add_argument(
"--share",
action="store_true",
help="Create a shareable link for the app"
)
parser.add_argument(
"--debug",
action="store_true",
help="Enable verbose debug output"
)
args = parser.parse_args()
if args.debug:
print(f"\033[92mINFO\033[0m: Debug mode enabled")
# Check if model exists
if not os.path.exists(args.model_path):
print(f"\033[91mERR!\033[0m: Model not found at {args.model_path}")
print("\033[92mINFO\033[0m: Please train a model first or provide a valid model path")
sys.exit(1)
# Create and launch the app
app = create_app(args.model_path)
app.launch(share=args.share)