# watermelon2 / test_moe_model.py
import os
import torch
import torchaudio
import torchvision
import numpy as np
import json
from torch.utils.data import Dataset, DataLoader
import sys
from tqdm import tqdm
# Add parent directory to path to import the preprocess functions
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from preprocess import process_audio_data, process_image_data
# Import the WatermelonDataset and WatermelonModelModular from the evaluate_backbones.py file
from evaluate_backbones import WatermelonDataset, WatermelonModelModular, IMAGE_BACKBONES, AUDIO_BACKBONES
# Print library versions
print(f"\033[92mINFO\033[0m: PyTorch version: {torch.__version__}")
print(f"\033[92mINFO\033[0m: Torchaudio version: {torchaudio.__version__}")
print(f"\033[92mINFO\033[0m: Torchvision version: {torchvision.__version__}")
# Device selection
device = torch.device(
"cuda" if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available()
else "cpu"
)
print(f"\033[92mINFO\033[0m: Using device: {device}")
# Define the top-performing models based on the previous evaluation
TOP_MODELS = [
{"image_backbone": "efficientnet_b3", "audio_backbone": "transformer"},
{"image_backbone": "efficientnet_b0", "audio_backbone": "transformer"},
{"image_backbone": "resnet50", "audio_backbone": "transformer"}
]
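# NOTE: each entry corresponds to a checkpoint expected at
# "<model_dir>/<image_backbone>_<audio_backbone>_model.pt" (see WatermelonMoEModel below).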
# Define class for the MoE model
class WatermelonMoEModel(torch.nn.Module):
def __init__(self, model_configs, model_dir="test_models", weights=None):
"""
Mixture of Experts model that combines multiple backbone models.
Args:
model_configs: List of dictionaries with 'image_backbone' and 'audio_backbone' keys
model_dir: Directory where model checkpoints are stored
weights: Optional list of weights for each model (None for equal weighting)
"""
        super().__init__()
        # ModuleList registers the experts as submodules so that .to(device)
        # and .eval() on the parent module propagate to each expert
        self.models = torch.nn.ModuleList()
self.model_configs = model_configs
# Load each model
for config in model_configs:
img_backbone = config["image_backbone"]
audio_backbone = config["audio_backbone"]
# Initialize model
model = WatermelonModelModular(img_backbone, audio_backbone)
# Load weights
model_path = os.path.join(model_dir, f"{img_backbone}_{audio_backbone}_model.pt")
if os.path.exists(model_path):
print(f"\033[92mINFO\033[0m: Loading model {img_backbone}_{audio_backbone} from {model_path}")
model.load_state_dict(torch.load(model_path, map_location=device))
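                # map_location ensures checkpoints saved on another device
                # (e.g. CUDA) load correctly onto the device selected above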
else:
print(f"\033[91mERR!\033[0m: Model checkpoint not found at {model_path}")
continue
model.to(device)
model.eval() # Set to evaluation mode
self.models.append(model)
        if len(self.models) == 0:
            raise RuntimeError(f"No model checkpoints could be loaded from {model_dir}")
        # Set model weights (uniform by default)
        if weights:
            assert len(weights) == len(self.models), "Number of weights must match number of models"
            self.weights = weights
        else:
            self.weights = [1.0 / len(self.models)] * len(self.models)
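        # NOTE: caller-supplied weights are used as-is; the ensemble output is a true
        # weighted average only if they sum to 1 (both the uniform weights above and
        # the performance-based weights in __main__ satisfy this)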
print(f"\033[92mINFO\033[0m: Loaded {len(self.models)} models for MoE ensemble")
print(f"\033[92mINFO\033[0m: Model weights: {self.weights}")
def forward(self, mfcc, image):
"""
Forward pass through the MoE model.
Returns the weighted average of all model outputs.
"""
outputs = []
# Get outputs from each model
with torch.no_grad():
for i, model in enumerate(self.models):
output = model(mfcc, image)
outputs.append(output * self.weights[i])
# Return weighted average
return torch.sum(torch.stack(outputs), dim=0)
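    # Sketch of the combination rule above with hypothetical numbers: three experts
    # predicting [1.2, 0.9, 1.5] under uniform weights of 1/3 each combine to
    #   sum(p * w for p, w in zip([1.2, 0.9, 1.5], [1/3] * 3))  # -> 1.2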
def evaluate_moe_model(data_dir, model_dir="test_models", weights=None):
"""
Evaluate the MoE model on the test set.
"""
# Load dataset
print(f"\033[92mINFO\033[0m: Loading dataset from {data_dir}")
dataset = WatermelonDataset(data_dir)
n_samples = len(dataset)
# Split dataset
train_size = int(0.7 * n_samples)
val_size = int(0.2 * n_samples)
test_size = n_samples - train_size - val_size
_, _, test_dataset = torch.utils.data.random_split(
dataset, [train_size, val_size, test_size]
)
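    # NOTE: random_split is unseeded here, so this test split will generally not
    # match the split used during training; pass generator=torch.Generator().manual_seed(...)
    # if the original seed is known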
# Use a reasonable batch size
batch_size = 8
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# Initialize MoE model
moe_model = WatermelonMoEModel(TOP_MODELS, model_dir, weights)
moe_model.eval()
# Evaluation metrics
mae_criterion = torch.nn.L1Loss()
mse_criterion = torch.nn.MSELoss()
    test_mae = 0.0
    test_mse = 0.0
    num_batches = 0  # count only successfully processed batches
print(f"\033[92mINFO\033[0m: Evaluating MoE model on {len(test_dataset)} test samples")
# Individual model predictions for analysis
individual_predictions = {f"{config['image_backbone']}_{config['audio_backbone']}": []
for config in TOP_MODELS}
true_labels = []
moe_predictions = []
# Evaluation loop
test_iterator = tqdm(test_loader, desc="Testing MoE")
with torch.no_grad():
for i, (mfcc, image, label) in enumerate(test_iterator):
try:
mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
# Store individual model outputs for analysis
for j, model in enumerate(moe_model.models):
config = TOP_MODELS[j]
model_name = f"{config['image_backbone']}_{config['audio_backbone']}"
output = model(mfcc, image)
individual_predictions[model_name].extend(output.view(-1).cpu().numpy())
# Get MoE prediction
output = moe_model(mfcc, image)
moe_predictions.extend(output.view(-1).cpu().numpy())
# Store true labels
label = label.view(-1, 1).float()
true_labels.extend(label.view(-1).cpu().numpy())
# Calculate metrics
mae = mae_criterion(output, label)
mse = mse_criterion(output, label)
                test_mae += mae.item()
                test_mse += mse.item()
                num_batches += 1
                test_iterator.set_postfix({"MAE": f"{mae.item():.4f}", "MSE": f"{mse.item():.4f}"})
# Clean up memory
if device.type == 'cuda':
del mfcc, image, label, output, mae, mse
torch.cuda.empty_cache()
except Exception as e:
print(f"\033[91mERR!\033[0m: Error in test batch {i}: {e}")
if device.type == 'cuda':
torch.cuda.empty_cache()
continue
    # Average over successfully processed batches; failed batches are skipped
    # in the loop above, so they must not inflate the denominator
    avg_test_mae = test_mae / num_batches if num_batches > 0 else float('inf')
    avg_test_mse = test_mse / num_batches if num_batches > 0 else float('inf')
print(f"\n\033[92mINFO\033[0m: === MoE Model Results ===")
print(f"Test MAE: {avg_test_mae:.4f}")
print(f"Test MSE: {avg_test_mse:.4f}")
# Compare with individual models
print(f"\n\033[92mINFO\033[0m: === Comparison with Individual Models ===")
print(f"{'Model':<30} {'Test MAE':<15}")
print("="*45)
# Load previous results
results_file = "backbone_evaluation_results.json"
if os.path.exists(results_file):
with open(results_file, 'r') as f:
previous_results = json.load(f)
# Filter results for our top models
for config in TOP_MODELS:
img_backbone = config["image_backbone"]
audio_backbone = config["audio_backbone"]
for result in previous_results:
if result["image_backbone"] == img_backbone and result["audio_backbone"] == audio_backbone:
print(f"{img_backbone}_{audio_backbone:<20} {result['test_mae']:<15.4f}")
print(f"MoE (Ensemble) {avg_test_mae:<15.4f}")
# Save results and predictions
results = {
"moe_test_mae": float(avg_test_mae),
"moe_test_mse": float(avg_test_mse),
"true_labels": [float(x) for x in true_labels],
"moe_predictions": [float(x) for x in moe_predictions],
"individual_predictions": {key: [float(x) for x in values]
for key, values in individual_predictions.items()}
}
with open("moe_evaluation_results.json", 'w') as f:
json.dump(results, f, indent=4)
print(f"\033[92mINFO\033[0m: Results saved to moe_evaluation_results.json")
return avg_test_mae, avg_test_mse
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Test Mixture of Experts (MoE) Model for Watermelon Sweetness Prediction")
parser.add_argument(
"--data_dir",
type=str,
default="../cleaned",
help="Path to the cleaned dataset directory"
)
parser.add_argument(
"--model_dir",
type=str,
default="test_models",
help="Directory containing model checkpoints"
)
parser.add_argument(
"--weighting",
type=str,
choices=["uniform", "performance"],
default="uniform",
help="How to weight the models (uniform or based on performance)"
)
args = parser.parse_args()
# Determine weights based on argument
weights = None
if args.weighting == "performance":
# Weights inversely proportional to the MAE (better models get higher weights)
# These are the MAE values from the provided results
mae_values = [0.3635, 0.3765, 0.3959] # efficientnet_b3+transformer, efficientnet_b0+transformer, resnet50+transformer
# Convert to weights (inverse of MAE, normalized)
inverse_mae = [1/mae for mae in mae_values]
total = sum(inverse_mae)
weights = [val/total for val in inverse_mae]
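        # Sanity check of the arithmetic: 1/MAE ≈ [2.751, 2.656, 2.526],
        # total ≈ 7.933, giving weights ≈ [0.347, 0.335, 0.318] (sum = 1.0)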
print(f"\033[92mINFO\033[0m: Using performance-based weights: {weights}")
else:
print(f"\033[92mINFO\033[0m: Using uniform weights")
# Evaluate the MoE model
evaluate_moe_model(args.data_dir, args.model_dir, weights)
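# Example invocation (defaults shown explicitly):
#   python test_moe_model.py --data_dir ../cleaned --model_dir test_models --weighting uniform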