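"""Inference script for the watermelon sweetness model.

Loads a trained WatermelonModel checkpoint and predicts sweetness either for a
single (audio, image) pair or for every sample in a cleaned dataset directory.
The dataset directory is expected to be laid out as
``<data_dir>/<sweetness>/<id>/<id>.wav`` plus ``<id>.jpg``, which is what
``infer_from_directory`` walks below. Audio and image preprocessing is
delegated to ``process_audio_data`` and ``process_image_data`` from
``preprocess.py``, and the model architecture comes from
``train_watermelon.py``; the exact input shapes are defined in those modules,
not here.
"""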
import os
import sys
import argparse

import numpy as np
import torch
import torchaudio
import torchvision

# Add parent directory to path to import the preprocess functions
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from preprocess import process_audio_data, process_image_data

# Import the model definition
from train_watermelon import WatermelonModel


def load_model(model_path):
    """Load a trained model from the given path."""
    device = torch.device(
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"\033[92mINFO\033[0m: Using device: {device}")

    model = WatermelonModel().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    print(f"\033[92mINFO\033[0m: Loaded model from {model_path}")

    return model, device


def infer_single_sample(audio_path, image_path, model, device):
    """Run inference on a single sample."""
    try:
        # Load and process audio
        waveform, sample_rate = torchaudio.load(audio_path)
        mfcc = process_audio_data(waveform, sample_rate).to(device)

        # Load and process image
        image = torchvision.io.read_image(image_path)
        image = image.float()
        processed_image = process_image_data(image).to(device)

        # Add batch dimension
        mfcc = mfcc.unsqueeze(0)
        processed_image = processed_image.unsqueeze(0)

        # Run inference
        with torch.no_grad():
            sweetness = model(mfcc, processed_image)

        return sweetness.item()
    except Exception as e:
        print(f"\033[91mERR!\033[0m: Error in inference: {e}")
        return None


def infer_from_directory(data_dir, model_path, output_file=None, num_samples=None):
    """Run inference on samples from the dataset directory."""
    # Load model
    model, device = load_model(model_path)

    # Collect all samples
    samples = []
    results = []

    print(f"\033[92mINFO\033[0m: Reading samples from {data_dir}")

    # Walk through the directory structure
    for sweetness_dir in os.listdir(data_dir):
        try:
            sweetness = float(sweetness_dir)
            sweetness_path = os.path.join(data_dir, sweetness_dir)

            if os.path.isdir(sweetness_path):
                for id_dir in os.listdir(sweetness_path):
                    id_path = os.path.join(sweetness_path, id_dir)

                    if os.path.isdir(id_path):
                        audio_file = os.path.join(id_path, f"{id_dir}.wav")
                        image_file = os.path.join(id_path, f"{id_dir}.jpg")

                        if os.path.exists(audio_file) and os.path.exists(image_file):
                            samples.append((audio_file, image_file, sweetness, id_dir))
        except ValueError:
            # Skip directories that are not valid sweetness values
            continue

    # Limit the number of samples if specified
    if num_samples is not None and num_samples > 0:
        samples = samples[:num_samples]

    print(f"\033[92mINFO\033[0m: Running inference on {len(samples)} samples")

    # Run inference on each sample
    for i, (audio_file, image_file, true_sweetness, sample_id) in enumerate(samples):
        print(f"\033[92mINFO\033[0m: Processing sample {i+1}/{len(samples)}: {sample_id}")

        predicted_sweetness = infer_single_sample(audio_file, image_file, model, device)

        if predicted_sweetness is not None:
            error = abs(predicted_sweetness - true_sweetness)
            results.append({
                'sample_id': sample_id,
                'true_sweetness': true_sweetness,
                'predicted_sweetness': predicted_sweetness,
                'error': error
            })

            print(f" Sample ID: {sample_id}")
            print(f" True sweetness: {true_sweetness:.2f}")
            print(f" Predicted sweetness: {predicted_sweetness:.2f}")
            print(f" Error: {error:.2f}")

    # Calculate mean absolute error
    if results:
        mae = np.mean([result['error'] for result in results])
        print(f"\033[92mINFO\033[0m: Mean Absolute Error: {mae:.4f}")

    # Save results to file if specified
    if output_file and results:
        with open(output_file, 'w') as f:
            f.write("sample_id,true_sweetness,predicted_sweetness,error\n")
            for result in results:
                f.write(
                    f"{result['sample_id']},{result['true_sweetness']:.2f},"
                    f"{result['predicted_sweetness']:.2f},{result['error']:.2f}\n"
                )
        print(f"\033[92mINFO\033[0m: Results saved to {output_file}")

    return results


def main():
    parser = argparse.ArgumentParser(description="Watermelon Sweetness Inference")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the trained model file")
    parser.add_argument("--data_dir", type=str, default="../cleaned", help="Path to the cleaned dataset directory")
    parser.add_argument("--output_file", type=str, help="Path to save inference results (CSV)")
    parser.add_argument("--num_samples", type=int, help="Number of samples to run inference on (default: all)")
    parser.add_argument("--audio_path", type=str, help="Path to a single audio file for inference")
    parser.add_argument("--image_path", type=str, help="Path to a single image file for inference")
    args = parser.parse_args()

    # Single-sample inference when both an audio and an image path are given,
    # otherwise run over the whole dataset directory
    if args.audio_path and args.image_path:
        model, device = load_model(args.model_path)
        sweetness = infer_single_sample(args.audio_path, args.image_path, model, device)
        if sweetness is not None:
            print(f"Predicted sweetness: {sweetness:.2f}")
    else:
        infer_from_directory(args.data_dir, args.model_path, args.output_file, args.num_samples)


if __name__ == "__main__":
    main()
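
# Example invocations (the script name and the checkpoint/sample paths below
# are illustrative; only the flags are defined by the argument parser above):
#
#   python infer.py --model_path models/watermelon.pt --data_dir ../cleaned \
#       --output_file results.csv --num_samples 10
#   python infer.py --model_path models/watermelon.pt \
#       --audio_path sample.wav --image_path sample.jpg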