import os import torch import torchaudio import torchvision import numpy as np import json from torch.utils.data import Dataset, DataLoader import sys from tqdm import tqdm # Add parent directory to path to import the preprocess functions sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from preprocess import process_audio_data, process_image_data # Import the WatermelonDataset and WatermelonModelModular from the evaluate_backbones.py file from evaluate_backbones import WatermelonDataset, WatermelonModelModular, IMAGE_BACKBONES, AUDIO_BACKBONES # Print library versions print(f"\033[92mINFO\033[0m: PyTorch version: {torch.__version__}") print(f"\033[92mINFO\033[0m: Torchaudio version: {torchaudio.__version__}") print(f"\033[92mINFO\033[0m: Torchvision version: {torchvision.__version__}") # Device selection device = torch.device( "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" ) print(f"\033[92mINFO\033[0m: Using device: {device}") # Define the top-performing models based on the previous evaluation TOP_MODELS = [ {"image_backbone": "efficientnet_b3", "audio_backbone": "transformer"}, {"image_backbone": "efficientnet_b0", "audio_backbone": "transformer"}, {"image_backbone": "resnet50", "audio_backbone": "transformer"} ] # Define class for the MoE model class WatermelonMoEModel(torch.nn.Module): def __init__(self, model_configs, model_dir="test_models", weights=None): """ Mixture of Experts model that combines multiple backbone models. Args: model_configs: List of dictionaries with 'image_backbone' and 'audio_backbone' keys model_dir: Directory where model checkpoints are stored weights: Optional list of weights for each model (None for equal weighting) """ super(WatermelonMoEModel, self).__init__() self.models = [] self.model_configs = model_configs # Load each model for config in model_configs: img_backbone = config["image_backbone"] audio_backbone = config["audio_backbone"] # Initialize model model = WatermelonModelModular(img_backbone, audio_backbone) # Load weights model_path = os.path.join(model_dir, f"{img_backbone}_{audio_backbone}_model.pt") if os.path.exists(model_path): print(f"\033[92mINFO\033[0m: Loading model {img_backbone}_{audio_backbone} from {model_path}") model.load_state_dict(torch.load(model_path, map_location=device)) else: print(f"\033[91mERR!\033[0m: Model checkpoint not found at {model_path}") continue model.to(device) model.eval() # Set to evaluation mode self.models.append(model) # Set model weights (uniform by default) if weights: assert len(weights) == len(self.models), "Number of weights must match number of models" self.weights = weights else: self.weights = [1.0 / len(self.models)] * len(self.models) print(f"\033[92mINFO\033[0m: Loaded {len(self.models)} models for MoE ensemble") print(f"\033[92mINFO\033[0m: Model weights: {self.weights}") def forward(self, mfcc, image): """ Forward pass through the MoE model. Returns the weighted average of all model outputs. """ outputs = [] # Get outputs from each model with torch.no_grad(): for i, model in enumerate(self.models): output = model(mfcc, image) outputs.append(output * self.weights[i]) # Return weighted average return torch.sum(torch.stack(outputs), dim=0) def evaluate_moe_model(data_dir, model_dir="test_models", weights=None): """ Evaluate the MoE model on the test set. """ # Load dataset print(f"\033[92mINFO\033[0m: Loading dataset from {data_dir}") dataset = WatermelonDataset(data_dir) n_samples = len(dataset) # Split dataset train_size = int(0.7 * n_samples) val_size = int(0.2 * n_samples) test_size = n_samples - train_size - val_size _, _, test_dataset = torch.utils.data.random_split( dataset, [train_size, val_size, test_size] ) # Use a reasonable batch size batch_size = 8 test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) # Initialize MoE model moe_model = WatermelonMoEModel(TOP_MODELS, model_dir, weights) moe_model.eval() # Evaluation metrics mae_criterion = torch.nn.L1Loss() mse_criterion = torch.nn.MSELoss() test_mae = 0.0 test_mse = 0.0 print(f"\033[92mINFO\033[0m: Evaluating MoE model on {len(test_dataset)} test samples") # Individual model predictions for analysis individual_predictions = {f"{config['image_backbone']}_{config['audio_backbone']}": [] for config in TOP_MODELS} true_labels = [] moe_predictions = [] # Evaluation loop test_iterator = tqdm(test_loader, desc="Testing MoE") with torch.no_grad(): for i, (mfcc, image, label) in enumerate(test_iterator): try: mfcc, image, label = mfcc.to(device), image.to(device), label.to(device) # Store individual model outputs for analysis for j, model in enumerate(moe_model.models): config = TOP_MODELS[j] model_name = f"{config['image_backbone']}_{config['audio_backbone']}" output = model(mfcc, image) individual_predictions[model_name].extend(output.view(-1).cpu().numpy()) # Get MoE prediction output = moe_model(mfcc, image) moe_predictions.extend(output.view(-1).cpu().numpy()) # Store true labels label = label.view(-1, 1).float() true_labels.extend(label.view(-1).cpu().numpy()) # Calculate metrics mae = mae_criterion(output, label) mse = mse_criterion(output, label) test_mae += mae.item() test_mse += mse.item() test_iterator.set_postfix({"MAE": f"{mae.item():.4f}", "MSE": f"{mse.item():.4f}"}) # Clean up memory if device.type == 'cuda': del mfcc, image, label, output, mae, mse torch.cuda.empty_cache() except Exception as e: print(f"\033[91mERR!\033[0m: Error in test batch {i}: {e}") if device.type == 'cuda': torch.cuda.empty_cache() continue # Calculate average metrics avg_test_mae = test_mae / len(test_loader) if len(test_loader) > 0 else float('inf') avg_test_mse = test_mse / len(test_loader) if len(test_loader) > 0 else float('inf') print(f"\n\033[92mINFO\033[0m: === MoE Model Results ===") print(f"Test MAE: {avg_test_mae:.4f}") print(f"Test MSE: {avg_test_mse:.4f}") # Compare with individual models print(f"\n\033[92mINFO\033[0m: === Comparison with Individual Models ===") print(f"{'Model':<30} {'Test MAE':<15}") print("="*45) # Load previous results results_file = "backbone_evaluation_results.json" if os.path.exists(results_file): with open(results_file, 'r') as f: previous_results = json.load(f) # Filter results for our top models for config in TOP_MODELS: img_backbone = config["image_backbone"] audio_backbone = config["audio_backbone"] for result in previous_results: if result["image_backbone"] == img_backbone and result["audio_backbone"] == audio_backbone: print(f"{img_backbone}_{audio_backbone:<20} {result['test_mae']:<15.4f}") print(f"MoE (Ensemble) {avg_test_mae:<15.4f}") # Save results and predictions results = { "moe_test_mae": float(avg_test_mae), "moe_test_mse": float(avg_test_mse), "true_labels": [float(x) for x in true_labels], "moe_predictions": [float(x) for x in moe_predictions], "individual_predictions": {key: [float(x) for x in values] for key, values in individual_predictions.items()} } with open("moe_evaluation_results.json", 'w') as f: json.dump(results, f, indent=4) print(f"\033[92mINFO\033[0m: Results saved to moe_evaluation_results.json") return avg_test_mae, avg_test_mse if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="Test Mixture of Experts (MoE) Model for Watermelon Sweetness Prediction") parser.add_argument( "--data_dir", type=str, default="../cleaned", help="Path to the cleaned dataset directory" ) parser.add_argument( "--model_dir", type=str, default="test_models", help="Directory containing model checkpoints" ) parser.add_argument( "--weighting", type=str, choices=["uniform", "performance"], default="uniform", help="How to weight the models (uniform or based on performance)" ) args = parser.parse_args() # Determine weights based on argument weights = None if args.weighting == "performance": # Weights inversely proportional to the MAE (better models get higher weights) # These are the MAE values from the provided results mae_values = [0.3635, 0.3765, 0.3959] # efficientnet_b3+transformer, efficientnet_b0+transformer, resnet50+transformer # Convert to weights (inverse of MAE, normalized) inverse_mae = [1/mae for mae in mae_values] total = sum(inverse_mae) weights = [val/total for val in inverse_mae] print(f"\033[92mINFO\033[0m: Using performance-based weights: {weights}") else: print(f"\033[92mINFO\033[0m: Using uniform weights") # Evaluate the MoE model evaluate_moe_model(args.data_dir, args.model_dir, weights)