Spaces:
Sleeping
Sleeping
import os | |
import sys | |
import torch | |
import torchaudio | |
import torchvision | |
import argparse | |
import numpy as np | |
# Add parent directory to path to import the preprocess functions | |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
from preprocess import process_audio_data, process_image_data | |
# Import the model definition | |
from train_watermelon import WatermelonModel | |
def load_model(model_path):
    """Build a WatermelonModel, load its weights, and switch it to eval mode.

    The compute device is chosen by preference: CUDA if available, otherwise
    Apple MPS if available, otherwise CPU.

    Args:
        model_path: Path to a file that ``torch.load`` returns as a state dict
            accepted by ``WatermelonModel.load_state_dict``.

    Returns:
        A ``(model, device)`` tuple; the model already lives on ``device``.
    """
    if torch.cuda.is_available():
        device_name = "cuda"
    elif torch.backends.mps.is_available():
        device_name = "mps"
    else:
        device_name = "cpu"
    device = torch.device(device_name)
    print(f"\033[92mINFO\033[0m: Using device: {device}")

    model = WatermelonModel().to(device)
    # map_location remaps saved tensors onto the selected device, so weights
    # saved on a GPU can still be loaded on a CPU-only machine.
    weights = torch.load(model_path, map_location=device)
    model.load_state_dict(weights)
    model.eval()
    print(f"\033[92mINFO\033[0m: Loaded model from {model_path}")

    return model, device
def infer_single_sample(audio_path, image_path, model, device):
    """Predict sweetness for one (audio, image) pair.

    Args:
        audio_path: File readable by ``torchaudio.load``.
        image_path: File readable by ``torchvision.io.read_image``.
        model: Model called as ``model(audio_batch, image_batch)``.
        device: Device onto which the preprocessed inputs are moved.

    Returns:
        The model output converted with ``.item()``, or ``None`` if any step
        (loading, preprocessing, or the forward pass) raised an exception.
        The exception message is printed rather than propagated.
    """
    try:
        signal, sr = torchaudio.load(audio_path)
        audio_features = process_audio_data(signal, sr).to(device)

        # read_image yields an integer tensor; preprocessing gets it as float.
        picture = torchvision.io.read_image(image_path).float()
        image_features = process_image_data(picture).to(device)

        # Prepend a batch dimension of size 1 before calling the model.
        audio_batch = audio_features.unsqueeze(0)
        image_batch = image_features.unsqueeze(0)

        with torch.no_grad():
            prediction = model(audio_batch, image_batch)
        return prediction.item()
    except Exception as exc:
        # Best-effort: callers treat None as "this sample failed".
        print(f"\033[91mERR!\033[0m: Error in inference: {exc}")
        return None
def _collect_samples(data_dir):
    """Return ``(audio_file, image_file, sweetness, sample_id)`` tuples under data_dir.

    Expected layout: ``<data_dir>/<sweetness>/<id>/<id>.wav`` and ``<id>.jpg``.
    Top-level entries whose name does not parse as a float are skipped, as
    are samples missing either file. Directory listings are sorted so the
    result (and any truncation of it) is the same on every run/filesystem.
    """
    samples = []
    for sweetness_dir in sorted(os.listdir(data_dir)):
        try:
            sweetness = float(sweetness_dir)
        except ValueError:
            # Not a sweetness-labelled directory; ignore it.
            continue
        sweetness_path = os.path.join(data_dir, sweetness_dir)
        if not os.path.isdir(sweetness_path):
            continue
        for id_dir in sorted(os.listdir(sweetness_path)):
            id_path = os.path.join(sweetness_path, id_dir)
            if not os.path.isdir(id_path):
                continue
            audio_file = os.path.join(id_path, f"{id_dir}.wav")
            image_file = os.path.join(id_path, f"{id_dir}.jpg")
            if os.path.exists(audio_file) and os.path.exists(image_file):
                samples.append((audio_file, image_file, sweetness, id_dir))
    return samples


def _write_results_csv(output_file, results):
    """Write per-sample results to ``output_file`` as CSV (values to 2 decimals)."""
    import csv

    # newline="" is required by the csv module; lineterminator="\n" keeps the
    # same line endings the script has always produced.
    with open(output_file, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(["sample_id", "true_sweetness", "predicted_sweetness", "error"])
        for result in results:
            writer.writerow([
                result['sample_id'],
                f"{result['true_sweetness']:.2f}",
                f"{result['predicted_sweetness']:.2f}",
                f"{result['error']:.2f}",
            ])


def infer_from_directory(data_dir, model_path, output_file=None, num_samples=None):
    """Run inference over a labelled dataset directory and report the error.

    The dataset is laid out as ``<data_dir>/<sweetness>/<id>/<id>.wav`` plus
    ``<id>.jpg``, where the ``<sweetness>`` folder name is the float label.

    Args:
        data_dir: Root of the dataset.
        model_path: Path to the saved weights, passed to ``load_model``.
        output_file: Optional CSV path for per-sample results; written only
            when at least one sample succeeded.
        num_samples: If a positive int, use only the first ``num_samples``
            samples in sorted directory order; otherwise use all samples.

    Returns:
        A list of dicts with keys ``sample_id``, ``true_sweetness``,
        ``predicted_sweetness`` and ``error`` (absolute error), one per
        sample whose inference succeeded.
    """
    model, device = load_model(model_path)

    print(f"\033[92mINFO\033[0m: Reading samples from {data_dir}")
    samples = _collect_samples(data_dir)

    if num_samples is not None and num_samples > 0:
        samples = samples[:num_samples]

    print(f"\033[92mINFO\033[0m: Running inference on {len(samples)} samples")

    results = []
    total = len(samples)
    for i, (audio_file, image_file, true_sweetness, sample_id) in enumerate(samples, start=1):
        print(f"\033[92mINFO\033[0m: Processing sample {i}/{total}: {sample_id}")
        predicted_sweetness = infer_single_sample(audio_file, image_file, model, device)
        if predicted_sweetness is None:
            # infer_single_sample already printed the failure; skip this sample.
            continue
        error = abs(predicted_sweetness - true_sweetness)
        results.append({
            'sample_id': sample_id,
            'true_sweetness': true_sweetness,
            'predicted_sweetness': predicted_sweetness,
            'error': error,
        })
        print(f"  Sample ID: {sample_id}")
        print(f"  True sweetness: {true_sweetness:.2f}")
        print(f"  Predicted sweetness: {predicted_sweetness:.2f}")
        print(f"  Error: {error:.2f}")

    if results:
        mae = np.mean([result['error'] for result in results])
        print(f"\033[92mINFO\033[0m: Mean Absolute Error: {mae:.4f}")
    else:
        print("\033[93mWARN\033[0m: No samples were processed successfully")

    if output_file and results:
        _write_results_csv(output_file, results)
        print(f"\033[92mINFO\033[0m: Results saved to {output_file}")

    return results
def main():
    """Command-line entry point: run single-sample or whole-dataset inference.

    Single-sample mode is used when both ``--audio_path`` and ``--image_path``
    are given. Supplying only one of them is a usage error, not a silent
    fallback to dataset mode. Exits with status 1 if single-sample inference
    fails.
    """
    parser = argparse.ArgumentParser(description="Watermelon Sweetness Inference")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the trained model file")
    parser.add_argument("--data_dir", type=str, default="../cleaned", help="Path to the cleaned dataset directory")
    parser.add_argument("--output_file", type=str, help="Path to save inference results (CSV)")
    parser.add_argument("--num_samples", type=int, help="Number of samples to run inference on (default: all)")
    parser.add_argument("--audio_path", type=str, help="Path to a single audio file for inference")
    parser.add_argument("--image_path", type=str, help="Path to a single image file for inference")
    args = parser.parse_args()

    # Only one of the pair is almost certainly a mistake; previously this
    # quietly ran inference over the whole dataset instead.
    if bool(args.audio_path) != bool(args.image_path):
        parser.error("--audio_path and --image_path must be provided together")

    if args.audio_path and args.image_path:
        # Single sample inference
        model, device = load_model(args.model_path)
        sweetness = infer_single_sample(args.audio_path, args.image_path, model, device)
        if sweetness is None:
            # Formatting None with :.2f would raise TypeError; the underlying
            # error has already been printed by infer_single_sample.
            sys.exit(1)
        print(f"Predicted sweetness: {sweetness:.2f}")
    else:
        # Dataset inference
        infer_from_directory(args.data_dir, args.model_path, args.output_file, args.num_samples)


if __name__ == "__main__":
    main()