# watermelon2/test_moe_model.py
import os
import torch
import torchaudio
import torchvision
import numpy as np
import json
from torch.utils.data import Dataset, DataLoader
import sys
from tqdm import tqdm
# Add parent directory to path to import the preprocess functions
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from preprocess import process_audio_data, process_image_data
# Import the WatermelonDataset and WatermelonModelModular from the evaluate_backbones.py file
from evaluate_backbones import WatermelonDataset, WatermelonModelModular, IMAGE_BACKBONES, AUDIO_BACKBONES
# Print library versions
print(f"\033[92mINFO\033[0m: PyTorch version: {torch.__version__}")
print(f"\033[92mINFO\033[0m: Torchaudio version: {torchaudio.__version__}")
print(f"\033[92mINFO\033[0m: Torchvision version: {torchvision.__version__}")
# Device selection
device = torch.device(
"cuda" if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available()
else "cpu"
)
print(f"\033[92mINFO\033[0m: Using device: {device}")
# Define the top-performing models based on the previous evaluation
TOP_MODELS = [
{"image_backbone": "efficientnet_b3", "audio_backbone": "transformer"},
{"image_backbone": "efficientnet_b0", "audio_backbone": "transformer"},
{"image_backbone": "resnet50", "audio_backbone": "transformer"}
]
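# These are the combinations with the lowest test MAE in the earlier backbone
# evaluation (their MAE values are hard-coded under the performance-based
# weighting option at the bottom of this file).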
# Define class for the MoE model
class WatermelonMoEModel(torch.nn.Module):
def __init__(self, model_configs, model_dir="models", weights=None):
"""
Mixture of Experts model that combines multiple backbone models.
Args:
model_configs: List of dictionaries with 'image_backbone' and 'audio_backbone' keys
model_dir: Directory where model checkpoints are stored
weights: Optional list of weights for each model (None for equal weighting)
"""
super(WatermelonMoEModel, self).__init__()
        # Register experts in a ModuleList so .to()/.eval() on the ensemble
        # propagate to each expert as well.
        self.models = torch.nn.ModuleList()
        # Track only the configs of experts that actually load, so indices into
        # self.models and self.model_configs stay aligned even if a checkpoint is missing
        self.model_configs = []
# Load each model
for config in model_configs:
img_backbone = config["image_backbone"]
audio_backbone = config["audio_backbone"]
# Initialize model
model = WatermelonModelModular(img_backbone, audio_backbone)
# Load weights
model_path = os.path.join(model_dir, f"{img_backbone}_{audio_backbone}_model.pt")
if os.path.exists(model_path):
print(f"\033[92mINFO\033[0m: Loading model {img_backbone}_{audio_backbone} from {model_path}")
model.load_state_dict(torch.load(model_path, map_location=device))
else:
print(f"\033[91mERR!\033[0m: Model checkpoint not found at {model_path}")
continue
model.to(device)
model.eval() # Set to evaluation mode
            self.models.append(model)
            self.model_configs.append(config)
        # Abort early if nothing loaded; otherwise the uniform weighting below would divide by zero
        if not self.models:
            raise RuntimeError(f"No expert checkpoints could be loaded from '{model_dir}'")
        # Set model weights (uniform by default)
        if weights:
            assert len(weights) == len(self.models), "Number of weights must match number of models"
            self.weights = weights
        else:
            self.weights = [1.0 / len(self.models)] * len(self.models)
print(f"\033[92mINFO\033[0m: Loaded {len(self.models)} models for MoE ensemble")
print(f"\033[92mINFO\033[0m: Model weights: {self.weights}")
def forward(self, mfcc, image):
"""
Forward pass through the MoE model.
Returns the weighted average of all model outputs.
"""
outputs = []
# Get outputs from each model
with torch.no_grad():
for i, model in enumerate(self.models):
output = model(mfcc, image)
print(f"DEBUG: Model {i} output: {output}")
outputs.append(output * self.weights[i])
# Return weighted average
final_output = torch.sum(torch.stack(outputs), dim=0)
print(f"DEBUG: Raw prediction: {final_output}")
return final_output
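# Minimal usage sketch (illustrative only; `mfcc_batch` and `image_batch` are
# placeholder names for tensors produced by process_audio_data/process_image_data):
#   moe = WatermelonMoEModel(TOP_MODELS, model_dir="models")
#   sweetness = moe(mfcc_batch, image_batch)  # weighted average of expert predictions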
def evaluate_moe_model(data_dir, model_dir="models", weights=None):
"""
Evaluate the MoE model on the test set.
"""
# Load dataset
print(f"\033[92mINFO\033[0m: Loading dataset from {data_dir}")
dataset = WatermelonDataset(data_dir)
n_samples = len(dataset)
# Split dataset
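    # Note: random_split is not seeded here, so the test split differs between runs
    # and is not guaranteed to match the split used when the expert checkpoints were trained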
train_size = int(0.7 * n_samples)
val_size = int(0.2 * n_samples)
test_size = n_samples - train_size - val_size
_, _, test_dataset = torch.utils.data.random_split(
dataset, [train_size, val_size, test_size]
)
# Use a reasonable batch size
batch_size = 8
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# Initialize MoE model
moe_model = WatermelonMoEModel(TOP_MODELS, model_dir, weights)
moe_model.eval()
# Evaluation metrics
mae_criterion = torch.nn.L1Loss()
mse_criterion = torch.nn.MSELoss()
test_mae = 0.0
test_mse = 0.0
print(f"\033[92mINFO\033[0m: Evaluating MoE model on {len(test_dataset)} test samples")
# Individual model predictions for analysis
individual_predictions = {f"{config['image_backbone']}_{config['audio_backbone']}": []
for config in TOP_MODELS}
true_labels = []
moe_predictions = []
# Evaluation loop
test_iterator = tqdm(test_loader, desc="Testing MoE")
with torch.no_grad():
for i, (mfcc, image, label) in enumerate(test_iterator):
try:
mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
# Store individual model outputs for analysis
                for j, model in enumerate(moe_model.models):
                    # Use the configs stored on the MoE model so names stay aligned
                    # with the experts that were actually loaded
                    config = moe_model.model_configs[j]
                    model_name = f"{config['image_backbone']}_{config['audio_backbone']}"
output = model(mfcc, image)
individual_predictions[model_name].extend(output.view(-1).cpu().numpy())
print(f"DEBUG: Model {j} output: {output}")
# Get MoE prediction
output = moe_model(mfcc, image)
moe_predictions.extend(output.view(-1).cpu().numpy())
print(f"DEBUG: MoE prediction: {output}")
# Store true labels
label = label.view(-1, 1).float()
true_labels.extend(label.view(-1).cpu().numpy())
# Calculate metrics
mae = mae_criterion(output, label)
mse = mse_criterion(output, label)
test_mae += mae.item()
test_mse += mse.item()
test_iterator.set_postfix({"MAE": f"{mae.item():.4f}", "MSE": f"{mse.item():.4f}"})
# Clean up memory
if device.type == 'cuda':
del mfcc, image, label, output, mae, mse
torch.cuda.empty_cache()
except Exception as e:
print(f"\033[91mERR!\033[0m: Error in test batch {i}: {e}")
if device.type == 'cuda':
torch.cuda.empty_cache()
continue
# Calculate average metrics
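    # (These are averages over batches, which equal per-sample averages only when
    # every batch has the same size; a smaller final batch is weighted slightly more.)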
avg_test_mae = test_mae / len(test_loader) if len(test_loader) > 0 else float('inf')
avg_test_mse = test_mse / len(test_loader) if len(test_loader) > 0 else float('inf')
print(f"\n\033[92mINFO\033[0m: === MoE Model Results ===")
print(f"Test MAE: {avg_test_mae:.4f}")
print(f"Test MSE: {avg_test_mse:.4f}")
# Compare with individual models
print(f"\n\033[92mINFO\033[0m: === Comparison with Individual Models ===")
print(f"{'Model':<30} {'Test MAE':<15}")
print("="*45)
# Load previous results
results_file = "backbone_evaluation_results.json"
if os.path.exists(results_file):
with open(results_file, 'r') as f:
previous_results = json.load(f)
# Filter results for our top models
for config in TOP_MODELS:
img_backbone = config["image_backbone"]
audio_backbone = config["audio_backbone"]
for result in previous_results:
if result["image_backbone"] == img_backbone and result["audio_backbone"] == audio_backbone:
print(f"{img_backbone}_{audio_backbone:<20} {result['test_mae']:<15.4f}")
print(f"MoE (Ensemble) {avg_test_mae:<15.4f}")
# Save results and predictions
results = {
"moe_test_mae": float(avg_test_mae),
"moe_test_mse": float(avg_test_mse),
"true_labels": [float(x) for x in true_labels],
"moe_predictions": [float(x) for x in moe_predictions],
"individual_predictions": {key: [float(x) for x in values]
for key, values in individual_predictions.items()}
}
with open("moe_evaluation_results.json", 'w') as f:
json.dump(results, f, indent=4)
print(f"\033[92mINFO\033[0m: Results saved to moe_evaluation_results.json")
return avg_test_mae, avg_test_mse
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Test Mixture of Experts (MoE) Model for Watermelon Sweetness Prediction")
parser.add_argument(
"--data_dir",
type=str,
default="../cleaned",
help="Path to the cleaned dataset directory"
)
parser.add_argument(
"--model_dir",
type=str,
default="models",
help="Directory containing model checkpoints"
)
parser.add_argument(
"--weighting",
type=str,
choices=["uniform", "performance"],
default="uniform",
help="How to weight the models (uniform or based on performance)"
)
args = parser.parse_args()
# Determine weights based on argument
weights = None
if args.weighting == "performance":
# Weights inversely proportional to the MAE (better models get higher weights)
# These are the MAE values from the provided results
mae_values = [0.3635, 0.3765, 0.3959] # efficientnet_b3+transformer, efficientnet_b0+transformer, resnet50+transformer
# Convert to weights (inverse of MAE, normalized)
inverse_mae = [1/mae for mae in mae_values]
total = sum(inverse_mae)
weights = [val/total for val in inverse_mae]
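        # With the MAE values above this works out to roughly [0.347, 0.335, 0.318],
        # i.e. the stronger experts are weighted slightly higher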
print(f"\033[92mINFO\033[0m: Using performance-based weights: {weights}")
else:
print(f"\033[92mINFO\033[0m: Using uniform weights")
# Evaluate the MoE model
evaluate_moe_model(args.data_dir, args.model_dir, weights)
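    # Example invocation (uses the defaults defined above; adjust paths as needed):
    #   python test_moe_model.py --data_dir ../cleaned --model_dir models --weighting performance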