"""Watermelon sweetness inference for a single sample or a dataset directory."""

import argparse
import os
import sys

import numpy as np
import torch
import torchaudio
import torchvision

# Add parent directory to path to import the preprocess functions.
# This must run before the project-local imports below.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from preprocess import process_audio_data, process_image_data

# Import the model definition
from train_watermelon import WatermelonModel


def _select_device():
    """Return the preferred torch device: CUDA, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # ``torch.backends.mps`` does not exist on older PyTorch builds, so look
    # it up defensively instead of assuming the attribute is present.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def load_model(model_path):
    """Load a trained model from the given path.

    Args:
        model_path: Path to a file holding a ``WatermelonModel`` state dict.

    Returns:
        A ``(model, device)`` tuple; the model is on ``device`` and in eval
        mode.
    """
    device = _select_device()
    print(f"\033[92mINFO\033[0m: Using device: {device}")

    model = WatermelonModel().to(device)
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources (consider ``weights_only=True`` where the installed
    # PyTorch version supports it).
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    print(f"\033[92mINFO\033[0m: Loaded model from {model_path}")

    return model, device


def infer_single_sample(audio_path, image_path, model, device):
    """Run inference on a single audio/image pair.

    Args:
        audio_path: Path to the audio file, read with ``torchaudio.load``.
        image_path: Path to the image file, read with
            ``torchvision.io.read_image``.
        model: Model called as ``model(mfcc, image)``.
        device: Device the processed inputs are moved to.

    Returns:
        The prediction as a Python number, or ``None`` if any step failed
        (the error is printed rather than raised).
    """
    # Deliberately best-effort: one unreadable sample must not abort a whole
    # dataset run, so every failure is reported and mapped to None.
    try:
        # Load and process audio
        waveform, sample_rate = torchaudio.load(audio_path)
        mfcc = process_audio_data(waveform, sample_rate).to(device)

        # Load and process image
        image = torchvision.io.read_image(image_path)
        image = image.float()
        processed_image = process_image_data(image).to(device)

        # Add batch dimension
        mfcc = mfcc.unsqueeze(0)
        processed_image = processed_image.unsqueeze(0)

        # Run inference
        with torch.no_grad():
            sweetness = model(mfcc, processed_image)

        # .item() assumes the model output holds exactly one element.
        return sweetness.item()
    except Exception as e:
        print(f"\033[91mERR!\033[0m: Error in inference: {e}")
        return None


def _collect_samples(data_dir):
    """Collect ``(audio, image, sweetness, sample_id)`` tuples under data_dir.

    The expected layout is ``<data_dir>/<sweetness>/<id>/<id>.wav`` plus
    ``<id>.jpg``. Entries whose name is not a float, non-directories and ids
    missing either file are skipped.
    """
    samples = []
    # os.listdir returns entries in arbitrary order; sorting keeps the sample
    # order (and therefore any ``num_samples`` prefix) reproducible.
    for sweetness_dir in sorted(os.listdir(data_dir)):
        try:
            sweetness = float(sweetness_dir)
        except ValueError:
            # Skip directories that are not valid sweetness values
            continue

        sweetness_path = os.path.join(data_dir, sweetness_dir)
        if not os.path.isdir(sweetness_path):
            continue

        for id_dir in sorted(os.listdir(sweetness_path)):
            id_path = os.path.join(sweetness_path, id_dir)
            if not os.path.isdir(id_path):
                continue

            audio_file = os.path.join(id_path, f"{id_dir}.wav")
            image_file = os.path.join(id_path, f"{id_dir}.jpg")
            if os.path.exists(audio_file) and os.path.exists(image_file):
                samples.append((audio_file, image_file, sweetness, id_dir))

    return samples


def infer_from_directory(data_dir, model_path, output_file=None, num_samples=None):
    """Run inference on samples from the dataset directory.

    Args:
        data_dir: Root directory laid out as ``<sweetness>/<id>/<id>.{wav,jpg}``.
        model_path: Path to the trained model file.
        output_file: Optional path; when given and there are results, they are
            written there as CSV.
        num_samples: Optional positive cap on the number of samples processed.

    Returns:
        A list of dicts with keys ``sample_id``, ``true_sweetness``,
        ``predicted_sweetness`` and ``error`` (absolute error), one per sample
        whose inference succeeded.
    """
    # Load model
    model, device = load_model(model_path)

    print(f"\033[92mINFO\033[0m: Reading samples from {data_dir}")
    samples = _collect_samples(data_dir)

    # Limit the number of samples if specified
    if num_samples is not None and num_samples > 0:
        samples = samples[:num_samples]

    print(f"\033[92mINFO\033[0m: Running inference on {len(samples)} samples")

    # Run inference on each sample
    results = []
    for i, (audio_file, image_file, true_sweetness, sample_id) in enumerate(samples):
        print(f"\033[92mINFO\033[0m: Processing sample {i+1}/{len(samples)}: {sample_id}")

        predicted_sweetness = infer_single_sample(audio_file, image_file, model, device)
        if predicted_sweetness is None:
            # The failure was already reported by infer_single_sample.
            continue

        error = abs(predicted_sweetness - true_sweetness)
        results.append({
            'sample_id': sample_id,
            'true_sweetness': true_sweetness,
            'predicted_sweetness': predicted_sweetness,
            'error': error,
        })

        print(f" Sample ID: {sample_id}")
        print(f" True sweetness: {true_sweetness:.2f}")
        print(f" Predicted sweetness: {predicted_sweetness:.2f}")
        print(f" Error: {error:.2f}")

    # Calculate mean absolute error
    if results:
        mae = np.mean([result['error'] for result in results])
        print(f"\033[92mINFO\033[0m: Mean Absolute Error: {mae:.4f}")

    # Save results to file if specified
    if output_file and results:
        with open(output_file, 'w') as f:
            f.write("sample_id,true_sweetness,predicted_sweetness,error\n")
            f.writelines(
                f"{result['sample_id']},{result['true_sweetness']:.2f},{result['predicted_sweetness']:.2f},{result['error']:.2f}\n"
                for result in results
            )
        print(f"\033[92mINFO\033[0m: Results saved to {output_file}")

    return results


def main():
    """Parse command-line arguments and run single-sample or dataset inference."""
    parser = argparse.ArgumentParser(description="Watermelon Sweetness Inference")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the trained model file")
    parser.add_argument("--data_dir", type=str, default="../cleaned", help="Path to the cleaned dataset directory")
    parser.add_argument("--output_file", type=str, help="Path to save inference results (CSV)")
    parser.add_argument("--num_samples", type=int, help="Number of samples to run inference on (default: all)")
    parser.add_argument("--audio_path", type=str, help="Path to a single audio file for inference")
    parser.add_argument("--image_path", type=str, help="Path to a single image file for inference")

    args = parser.parse_args()

    # Check if single sample inference or dataset inference
    if args.audio_path and args.image_path:
        # Single sample inference
        model, device = load_model(args.model_path)
        sweetness = infer_single_sample(args.audio_path, args.image_path, model, device)
        if sweetness is None:
            # infer_single_sample already printed the error; formatting None
            # with ':.2f' would raise TypeError, so exit with a failure code.
            sys.exit(1)
        print(f"Predicted sweetness: {sweetness:.2f}")
    else:
        if args.audio_path or args.image_path:
            # Only one half of the pair was supplied; say so instead of
            # silently switching modes.
            print(
                "\033[93mWARN\033[0m: --audio_path and --image_path must be given "
                "together for single-sample inference; running dataset inference instead"
            )
        # Dataset inference
        infer_from_directory(args.data_dir, args.model_path, args.output_file, args.num_samples)


if __name__ == "__main__":
    main()