# watermelon2 / infer_watermelon.py
# Uploaded by Xalphinions; commit 5900417 (verified): "Upload folder using huggingface_hub"
# File size: 6.08 kB
import os
import sys
import torch
import torchaudio
import torchvision
import argparse
import numpy as np
# Add parent directory to path to import the preprocess functions
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from preprocess import process_audio_data, process_image_data
# Import the model definition
from train_watermelon import WatermelonModel
def load_model(model_path):
    """Instantiate a WatermelonModel, restore its weights, and switch it to eval mode.

    The compute device is chosen in order of preference: CUDA first, then
    Apple MPS, and finally the CPU.

    Args:
        model_path: Path to a checkpoint whose contents are passed directly
            to ``load_state_dict``.

    Returns:
        A ``(model, device)`` tuple: the model placed on the selected device
        in evaluation mode, and the ``torch.device`` that was selected.
    """
    # MPS is only probed when CUDA is unavailable, mirroring the preference order.
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    device = torch.device(backend)
    print(f"\033[92mINFO\033[0m: Using device: {device}")

    net = WatermelonModel().to(device)
    # NOTE(review): depending on the torch version, torch.load may unpickle
    # arbitrary objects -- only load checkpoints from a trusted source.
    state_dict = torch.load(model_path, map_location=device)
    net.load_state_dict(state_dict)
    net.eval()
    print(f"\033[92mINFO\033[0m: Loaded model from {model_path}")

    return net, device
def infer_single_sample(audio_path, image_path, model, device):
    """Predict the sweetness value for one audio/image pair.

    Args:
        audio_path: Path to the audio file, read with ``torchaudio.load``.
        image_path: Path to the image file, read with
            ``torchvision.io.read_image``.
        model: Callable model invoked as ``model(audio_batch, image_batch)``.
        device: Device the preprocessed tensors are moved to before the
            forward pass.

    Returns:
        The model output converted with ``.item()``, or ``None`` when any
        step (loading, preprocessing, or the forward pass) raises. The error
        is printed rather than propagated so that batch runs can continue.
    """
    try:
        # Audio branch: raw waveform -> preprocessed features on the target device.
        waveform, sample_rate = torchaudio.load(audio_path)
        audio_features = process_audio_data(waveform, sample_rate).to(device)

        # Image branch: the decoded image is cast to float before preprocessing.
        raw_image = torchvision.io.read_image(image_path).float()
        image_features = process_image_data(raw_image).to(device)

        # Prepend a batch axis of size 1 to each input before calling the model.
        audio_batch = audio_features.unsqueeze(0)
        image_batch = image_features.unsqueeze(0)

        # Gradients are not needed for inference.
        with torch.no_grad():
            prediction = model(audio_batch, image_batch)
        return prediction.item()
    except Exception as e:
        print(f"\033[91mERR!\033[0m: Error in inference: {e}")
        return None
def _collect_samples(data_dir):
    """Return ``(audio_file, image_file, sweetness, sample_id)`` tuples found under ``data_dir``.

    The expected layout is ``<data_dir>/<sweetness>/<id>/<id>.wav`` plus
    ``<id>.jpg``. Top-level entries whose name does not parse as a float are
    skipped, as are sample directories missing either file.
    """
    samples = []
    # os.listdir returns entries in arbitrary order; sorting makes the sample
    # order (and therefore any num_samples prefix) reproducible between runs.
    for sweetness_dir in sorted(os.listdir(data_dir)):
        # Only the float parse is guarded, so unrelated ValueErrors raised
        # while walking the tree are not silently swallowed.
        try:
            sweetness = float(sweetness_dir)
        except ValueError:
            # Skip directories that are not valid sweetness values
            continue

        sweetness_path = os.path.join(data_dir, sweetness_dir)
        if not os.path.isdir(sweetness_path):
            continue

        for id_dir in sorted(os.listdir(sweetness_path)):
            id_path = os.path.join(sweetness_path, id_dir)
            if not os.path.isdir(id_path):
                continue
            audio_file = os.path.join(id_path, f"{id_dir}.wav")
            image_file = os.path.join(id_path, f"{id_dir}.jpg")
            if os.path.exists(audio_file) and os.path.exists(image_file):
                samples.append((audio_file, image_file, sweetness, id_dir))
    return samples


def infer_from_directory(data_dir, model_path, output_file=None, num_samples=None):
    """Run inference on every sample in a dataset directory and report the errors.

    Args:
        data_dir: Root directory laid out as ``<sweetness>/<id>/<id>.wav`` and
            ``<id>.jpg``; the sweetness directory name is the ground truth.
        model_path: Path to the trained model checkpoint.
        output_file: Optional path of a CSV file to write the results to.
            Nothing is written when there are no successful predictions.
        num_samples: Optional cap on the number of samples processed. ``None``
            or a non-positive value means "all samples".

    Returns:
        A list of dicts with the keys ``sample_id``, ``true_sweetness``,
        ``predicted_sweetness`` and ``error`` (absolute error), one per sample
        for which inference succeeded.
    """
    # Load model
    model, device = load_model(model_path)

    print(f"\033[92mINFO\033[0m: Reading samples from {data_dir}")
    samples = _collect_samples(data_dir)

    # Limit the number of samples if specified
    if num_samples is not None and num_samples > 0:
        samples = samples[:num_samples]
    print(f"\033[92mINFO\033[0m: Running inference on {len(samples)} samples")

    # Run inference on each sample
    results = []
    for i, (audio_file, image_file, true_sweetness, sample_id) in enumerate(samples):
        print(f"\033[92mINFO\033[0m: Processing sample {i+1}/{len(samples)}: {sample_id}")
        predicted_sweetness = infer_single_sample(audio_file, image_file, model, device)
        if predicted_sweetness is None:
            # The failure has already been reported; leave it out of the metrics.
            continue

        error = abs(predicted_sweetness - true_sweetness)
        results.append({
            'sample_id': sample_id,
            'true_sweetness': true_sweetness,
            'predicted_sweetness': predicted_sweetness,
            'error': error
        })
        print(f" Sample ID: {sample_id}")
        print(f" True sweetness: {true_sweetness:.2f}")
        print(f" Predicted sweetness: {predicted_sweetness:.2f}")
        print(f" Error: {error:.2f}")

    # Calculate mean absolute error
    if results:
        mae = np.mean([result['error'] for result in results])
        print(f"\033[92mINFO\033[0m: Mean Absolute Error: {mae:.4f}")

    # Save results to file if specified
    if output_file and results:
        with open(output_file, 'w') as f:
            f.write("sample_id,true_sweetness,predicted_sweetness,error\n")
            for result in results:
                f.write(f"{result['sample_id']},{result['true_sweetness']:.2f},{result['predicted_sweetness']:.2f},{result['error']:.2f}\n")
        print(f"\033[92mINFO\033[0m: Results saved to {output_file}")

    return results
def main(argv=None):
    """Command-line entry point for watermelon sweetness inference.

    Runs single-sample inference when both ``--audio_path`` and
    ``--image_path`` are given, and dataset inference over ``--data_dir``
    otherwise.

    Args:
        argv: Optional list of argument strings. When ``None`` argparse reads
            the arguments from ``sys.argv``.

    Raises:
        SystemExit: With status 2 when only one of ``--audio_path`` /
            ``--image_path`` is supplied (or on any other argparse error),
            and with status 1 when single-sample inference fails.
    """
    parser = argparse.ArgumentParser(description="Watermelon Sweetness Inference")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the trained model file")
    parser.add_argument("--data_dir", type=str, default="../cleaned", help="Path to the cleaned dataset directory")
    parser.add_argument("--output_file", type=str, help="Path to save inference results (CSV)")
    parser.add_argument("--num_samples", type=int, help="Number of samples to run inference on (default: all)")
    parser.add_argument("--audio_path", type=str, help="Path to a single audio file for inference")
    parser.add_argument("--image_path", type=str, help="Path to a single image file for inference")
    args = parser.parse_args(argv)

    # Supplying only one of the two paths used to fall through silently to a
    # full dataset run; reject it so the user notices the missing argument.
    if bool(args.audio_path) != bool(args.image_path):
        parser.error("--audio_path and --image_path must be provided together")

    # Check if single sample inference or dataset inference
    if args.audio_path and args.image_path:
        # Single sample inference
        model, device = load_model(args.model_path)
        sweetness = infer_single_sample(args.audio_path, args.image_path, model, device)
        if sweetness is None:
            # infer_single_sample has already printed the error; formatting
            # None with ":.2f" would raise TypeError, so exit with a failure
            # status instead.
            sys.exit(1)
        print(f"Predicted sweetness: {sweetness:.2f}")
    else:
        # Dataset inference
        infer_from_directory(args.data_dir, args.model_path, args.output_file, args.num_samples)


if __name__ == "__main__":
    main()