File size: 6,082 Bytes
5900417
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import argparse
import csv
import os
import sys

import numpy as np
import torch
import torchaudio
import torchvision

# Add parent directory to path to import the preprocess functions
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from preprocess import process_audio_data, process_image_data

# Import the model definition
from train_watermelon import WatermelonModel

def _pick_device():
    """Return the preferred torch device: CUDA first, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def load_model(model_path):
    """Build a WatermelonModel, load trained weights into it and set eval mode.

    Args:
        model_path: Path to a file readable by ``torch.load`` whose contents
            are passed to ``WatermelonModel.load_state_dict``.

    Returns:
        ``(model, device)``. The model lives on ``device`` and is in eval mode.
    """
    target = _pick_device()
    print(f"\033[92mINFO\033[0m: Using device: {target}")

    net = WatermelonModel().to(target)
    # map_location lets weights saved on one device (e.g. CUDA) load on another.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    weights = torch.load(model_path, map_location=target)
    net.load_state_dict(weights)
    net.eval()  # inference mode for layers such as dropout / batch-norm
    print(f"\033[92mINFO\033[0m: Loaded model from {model_path}")

    return net, target

def infer_single_sample(audio_path, image_path, model, device):
    """Predict sweetness for one audio/image pair.

    Args:
        audio_path: Audio file, read with ``torchaudio.load``.
        image_path: Image file, read with ``torchvision.io.read_image``.
        model: Trained model, called as ``model(audio_batch, image_batch)``.
        device: Torch device the preprocessed inputs are moved to.

    Returns:
        The model output converted with ``.item()``, or ``None`` if any step
        raises. Errors are printed rather than propagated. The output is
        presumably a single-element tensor, as ``.item()`` requires.
    """
    try:
        # Audio branch: raw waveform -> features via the shared preprocess helper.
        signal, rate = torchaudio.load(audio_path)
        audio_features = process_audio_data(signal, rate).to(device)

        # Image branch: cast to float before handing off to preprocessing.
        picture = torchvision.io.read_image(image_path).float()
        image_features = process_image_data(picture).to(device)

        # Feed the model a batch of one.
        audio_batch = audio_features.unsqueeze(0)
        image_batch = image_features.unsqueeze(0)

        with torch.no_grad():
            prediction = model(audio_batch, image_batch)
        return prediction.item()
    except Exception as exc:
        print(f"\033[91mERR!\033[0m: Error in inference: {exc}")
        return None

def _collect_samples(data_dir):
    """Return ``(audio_file, image_file, sweetness, sample_id)`` tuples under *data_dir*.

    Layout walked below::

        data_dir/<sweetness>/<id>/<id>.wav
        data_dir/<sweetness>/<id>/<id>.jpg

    Top-level entries whose name does not parse as a finite float are skipped.
    Sample directories missing either file are skipped too. Listings are
    sorted because ``os.listdir`` order is arbitrary; sorting makes any
    ``num_samples`` prefix reproducible.
    """
    samples = []
    for sweetness_dir in sorted(os.listdir(data_dir)):
        try:
            sweetness = float(sweetness_dir)
        except ValueError:
            # Skip directories that are not valid sweetness values
            continue
        if not np.isfinite(sweetness):
            # float() also accepts "nan"/"inf"; those labels would poison the MAE.
            continue

        sweetness_path = os.path.join(data_dir, sweetness_dir)
        if not os.path.isdir(sweetness_path):
            continue

        for id_dir in sorted(os.listdir(sweetness_path)):
            id_path = os.path.join(sweetness_path, id_dir)
            if not os.path.isdir(id_path):
                continue
            audio_file = os.path.join(id_path, f"{id_dir}.wav")
            image_file = os.path.join(id_path, f"{id_dir}.jpg")
            if os.path.exists(audio_file) and os.path.exists(image_file):
                samples.append((audio_file, image_file, sweetness, id_dir))
    return samples


def _write_results_csv(output_file, results):
    """Write per-sample results to *output_file* as CSV with a header row.

    Uses the ``csv`` module so sample IDs containing commas or quotes are
    quoted correctly. Numeric columns keep two decimals, and each row ends
    in a bare newline, matching the previous hand-written format.
    """
    with open(output_file, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(["sample_id", "true_sweetness", "predicted_sweetness", "error"])
        for result in results:
            writer.writerow([
                result["sample_id"],
                f"{result['true_sweetness']:.2f}",
                f"{result['predicted_sweetness']:.2f}",
                f"{result['error']:.2f}",
            ])


def infer_from_directory(data_dir, model_path, output_file=None, num_samples=None):
    """Run inference on every sample under *data_dir* and report absolute errors.

    Args:
        data_dir: Dataset root laid out as ``<sweetness>/<id>/<id>.{wav,jpg}``.
        model_path: Checkpoint passed to ``load_model``.
        output_file: If given and at least one sample succeeded, results are
            written there as CSV.
        num_samples: If a positive int, only the first ``num_samples``
            samples (in sorted directory order) are processed.

    Returns:
        A list of dicts with keys ``sample_id``, ``true_sweetness``,
        ``predicted_sweetness`` and ``error`` (absolute error). There is one
        dict per sample whose inference succeeded.
    """
    model, device = load_model(model_path)

    print(f"\033[92mINFO\033[0m: Reading samples from {data_dir}")
    samples = _collect_samples(data_dir)

    # Limit the number of samples if specified
    if num_samples is not None and num_samples > 0:
        samples = samples[:num_samples]

    print(f"\033[92mINFO\033[0m: Running inference on {len(samples)} samples")

    results = []
    for i, (audio_file, image_file, true_sweetness, sample_id) in enumerate(samples, start=1):
        print(f"\033[92mINFO\033[0m: Processing sample {i}/{len(samples)}: {sample_id}")

        predicted_sweetness = infer_single_sample(audio_file, image_file, model, device)
        if predicted_sweetness is None:
            # infer_single_sample already reported the failure; exclude it from metrics.
            continue

        error = abs(predicted_sweetness - true_sweetness)
        results.append({
            'sample_id': sample_id,
            'true_sweetness': true_sweetness,
            'predicted_sweetness': predicted_sweetness,
            'error': error,
        })
        print(f"  Sample ID: {sample_id}")
        print(f"  True sweetness: {true_sweetness:.2f}")
        print(f"  Predicted sweetness: {predicted_sweetness:.2f}")
        print(f"  Error: {error:.2f}")

    # Calculate mean absolute error
    if results:
        mae = np.mean([result['error'] for result in results])
        print(f"\033[92mINFO\033[0m: Mean Absolute Error: {mae:.4f}")

    # Save results to file if specified
    if output_file and results:
        _write_results_csv(output_file, results)
        print(f"\033[92mINFO\033[0m: Results saved to {output_file}")

    return results

def main(argv=None):
    """Command-line entry point for watermelon sweetness inference.

    Modes:
      * Single sample: both ``--audio_path`` and ``--image_path`` are given.
        The predicted sweetness is printed, or the process exits with status
        1 if inference failed.
      * Dataset: neither is given. Runs ``infer_from_directory`` over
        ``--data_dir``.

    Supplying only one of ``--audio_path``/``--image_path`` is a usage error.

    Args:
        argv: Optional argument list. ``None`` means argparse reads
            ``sys.argv[1:]``, as before. Useful for calling from code.
    """
    parser = argparse.ArgumentParser(description="Watermelon Sweetness Inference")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the trained model file")
    parser.add_argument("--data_dir", type=str, default="../cleaned", help="Path to the cleaned dataset directory")
    parser.add_argument("--output_file", type=str, help="Path to save inference results (CSV)")
    parser.add_argument("--num_samples", type=int, help="Number of samples to run inference on (default: all)")
    parser.add_argument("--audio_path", type=str, help="Path to a single audio file for inference")
    parser.add_argument("--image_path", type=str, help="Path to a single image file for inference")

    args = parser.parse_args(argv)

    # Previously a lone --audio_path or --image_path silently fell through to
    # dataset mode; report it as a usage error instead (exits with status 2).
    if bool(args.audio_path) != bool(args.image_path):
        parser.error("--audio_path and --image_path must be provided together")

    if args.audio_path and args.image_path:
        # Single sample inference
        model, device = load_model(args.model_path)
        sweetness = infer_single_sample(args.audio_path, args.image_path, model, device)
        if sweetness is None:
            # infer_single_sample already printed the error; formatting None
            # with ":.2f" would raise TypeError, so exit with a failure status.
            sys.exit(1)
        print(f"Predicted sweetness: {sweetness:.2f}")
    else:
        # Dataset inference
        infer_from_directory(args.data_dir, args.model_path, args.output_file, args.num_samples)


if __name__ == "__main__":
    main()