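"""Evaluate a Mixture-of-Experts (MoE) ensemble that combines the top-performing
image/audio backbone pairs for watermelon sweetness prediction, and compare it
against the individual models from the earlier backbone evaluation."""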
import os
import torch
import torchaudio
import torchvision
import numpy as np
import json
from torch.utils.data import Dataset, DataLoader
import sys
from tqdm import tqdm

# Add parent directory to path to import the preprocess functions
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from preprocess import process_audio_data, process_image_data

# Import the WatermelonDataset and WatermelonModelModular from evaluate_backbones.py
from evaluate_backbones import WatermelonDataset, WatermelonModelModular, IMAGE_BACKBONES, AUDIO_BACKBONES

# Print library versions
print(f"\033[92mINFO\033[0m: PyTorch version: {torch.__version__}")
print(f"\033[92mINFO\033[0m: Torchaudio version: {torchaudio.__version__}")
print(f"\033[92mINFO\033[0m: Torchvision version: {torchvision.__version__}")

# Device selection: prefer CUDA, then Apple MPS, then CPU
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"\033[92mINFO\033[0m: Using device: {device}")
# Define the top-performing models based on the previous evaluation
TOP_MODELS = [
    {"image_backbone": "efficientnet_b3", "audio_backbone": "transformer"},
    {"image_backbone": "efficientnet_b0", "audio_backbone": "transformer"},
    {"image_backbone": "resnet50", "audio_backbone": "transformer"},
]

# Define class for the MoE model
class WatermelonMoEModel(torch.nn.Module):
    def __init__(self, model_configs, model_dir="models", weights=None):
        """
        Mixture of Experts model that combines multiple backbone models.

        Args:
            model_configs: List of dictionaries with 'image_backbone' and 'audio_backbone' keys
            model_dir: Directory where model checkpoints are stored
            weights: Optional list of weights for each model (None for equal weighting)
        """
        super(WatermelonMoEModel, self).__init__()
        self.models = []
        self.model_configs = model_configs

        # Load each expert model and its checkpoint
        for config in model_configs:
            img_backbone = config["image_backbone"]
            audio_backbone = config["audio_backbone"]

            # Initialize model
            model = WatermelonModelModular(img_backbone, audio_backbone)

            # Load weights
            model_path = os.path.join(model_dir, f"{img_backbone}_{audio_backbone}_model.pt")
            if os.path.exists(model_path):
                print(f"\033[92mINFO\033[0m: Loading model {img_backbone}_{audio_backbone} from {model_path}")
                model.load_state_dict(torch.load(model_path, map_location=device))
            else:
                print(f"\033[91mERR!\033[0m: Model checkpoint not found at {model_path}")
                continue

            model.to(device)
            model.eval()  # Set to evaluation mode
            self.models.append(model)

        # Fail early with a clear message instead of a ZeroDivisionError below
        if not self.models:
            raise RuntimeError("No model checkpoints could be loaded; cannot build the MoE ensemble")

        # Set model weights (uniform by default)
        if weights:
            assert len(weights) == len(self.models), "Number of weights must match number of models"
            self.weights = weights
        else:
            self.weights = [1.0 / len(self.models)] * len(self.models)

        print(f"\033[92mINFO\033[0m: Loaded {len(self.models)} models for MoE ensemble")
        print(f"\033[92mINFO\033[0m: Model weights: {self.weights}")

    def forward(self, mfcc, image):
        """
        Forward pass through the MoE model.
        Returns the weighted average of all model outputs.
        """
        outputs = []

        # Get outputs from each expert model
        with torch.no_grad():
            for i, model in enumerate(self.models):
                output = model(mfcc, image)
                print(f"DEBUG: Model {i} output: {output}")
                outputs.append(output * self.weights[i])

        # Return weighted average
        final_output = torch.sum(torch.stack(outputs), dim=0)
        print(f"DEBUG: Raw prediction: {final_output}")
        return final_output
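

# Minimal sketch (not called anywhere in this script) of the weighted combination
# performed in WatermelonMoEModel.forward(): each expert's output is scaled by its
# weight and the scaled outputs are summed. The tensor values and weights below are
# made up purely for illustration.
def _weighted_average_sketch():
    outputs = [torch.tensor([[10.0]]), torch.tensor([[11.0]]), torch.tensor([[12.0]])]
    weights = [0.36, 0.34, 0.30]
    # Same operation as forward(): sum of weight-scaled expert outputs
    combined = torch.sum(torch.stack([o * w for o, w in zip(outputs, weights)]), dim=0)
    return combined  # tensor([[10.9400]])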


def evaluate_moe_model(data_dir, model_dir="models", weights=None):
    """
    Evaluate the MoE model on the test set.
    """
    # Load dataset
    print(f"\033[92mINFO\033[0m: Loading dataset from {data_dir}")
    dataset = WatermelonDataset(data_dir)
    n_samples = len(dataset)

    # Split dataset
    train_size = int(0.7 * n_samples)
    val_size = int(0.2 * n_samples)
    test_size = n_samples - train_size - val_size
    _, _, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size]
    )
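    # Note: the split above is 70/20/10 train/val/test. random_split without an
    # explicit generator draws a different split on every run, so the test set here
    # may not match the one used when the individual backbones were evaluated;
    # passing generator=torch.Generator().manual_seed(42) (seed value arbitrary)
    # would make the split reproducible.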

    # Use a reasonable batch size
    batch_size = 8
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Initialize MoE model
    moe_model = WatermelonMoEModel(TOP_MODELS, model_dir, weights)
    moe_model.eval()

    # Evaluation metrics
    mae_criterion = torch.nn.L1Loss()
    mse_criterion = torch.nn.MSELoss()
    test_mae = 0.0
    test_mse = 0.0

    print(f"\033[92mINFO\033[0m: Evaluating MoE model on {len(test_dataset)} test samples")

    # Individual model predictions for analysis
    individual_predictions = {f"{config['image_backbone']}_{config['audio_backbone']}": []
                              for config in TOP_MODELS}
    true_labels = []
    moe_predictions = []

    # Evaluation loop
    test_iterator = tqdm(test_loader, desc="Testing MoE")
    with torch.no_grad():
        for i, (mfcc, image, label) in enumerate(test_iterator):
            try:
                mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)

                # Store individual model outputs for analysis
                for j, model in enumerate(moe_model.models):
                    config = TOP_MODELS[j]
                    model_name = f"{config['image_backbone']}_{config['audio_backbone']}"
                    output = model(mfcc, image)
                    individual_predictions[model_name].extend(output.view(-1).cpu().numpy())
                    print(f"DEBUG: Model {j} output: {output}")

                # Get MoE prediction
                output = moe_model(mfcc, image)
                moe_predictions.extend(output.view(-1).cpu().numpy())
                print(f"DEBUG: MoE prediction: {output}")

                # Store true labels
                label = label.view(-1, 1).float()
                true_labels.extend(label.view(-1).cpu().numpy())

                # Calculate metrics
                mae = mae_criterion(output, label)
                mse = mse_criterion(output, label)
                test_mae += mae.item()
                test_mse += mse.item()

                test_iterator.set_postfix({"MAE": f"{mae.item():.4f}", "MSE": f"{mse.item():.4f}"})

                # Clean up memory
                if device.type == 'cuda':
                    del mfcc, image, label, output, mae, mse
                    torch.cuda.empty_cache()
            except Exception as e:
                print(f"\033[91mERR!\033[0m: Error in test batch {i}: {e}")
                if device.type == 'cuda':
                    torch.cuda.empty_cache()
                continue

    # Calculate average metrics
    avg_test_mae = test_mae / len(test_loader) if len(test_loader) > 0 else float('inf')
    avg_test_mse = test_mse / len(test_loader) if len(test_loader) > 0 else float('inf')

    print(f"\n\033[92mINFO\033[0m: === MoE Model Results ===")
    print(f"Test MAE: {avg_test_mae:.4f}")
    print(f"Test MSE: {avg_test_mse:.4f}")

    # Compare with individual models
    print(f"\n\033[92mINFO\033[0m: === Comparison with Individual Models ===")
    print(f"{'Model':<30} {'Test MAE':<15}")
    print("=" * 45)

    # Load previous results
    results_file = "backbone_evaluation_results.json"
    if os.path.exists(results_file):
        with open(results_file, 'r') as f:
            previous_results = json.load(f)

        # Filter results for our top models
        for config in TOP_MODELS:
            img_backbone = config["image_backbone"]
            audio_backbone = config["audio_backbone"]
            for result in previous_results:
                if result["image_backbone"] == img_backbone and result["audio_backbone"] == audio_backbone:
                    model_name = f"{img_backbone}_{audio_backbone}"
                    print(f"{model_name:<30} {result['test_mae']:<15.4f}")

    print(f"{'MoE (Ensemble)':<30} {avg_test_mae:<15.4f}")

    # Save results and predictions
    results = {
        "moe_test_mae": float(avg_test_mae),
        "moe_test_mse": float(avg_test_mse),
        "true_labels": [float(x) for x in true_labels],
        "moe_predictions": [float(x) for x in moe_predictions],
        "individual_predictions": {key: [float(x) for x in values]
                                   for key, values in individual_predictions.items()}
    }
    with open("moe_evaluation_results.json", 'w') as f:
        json.dump(results, f, indent=4)
    print(f"\033[92mINFO\033[0m: Results saved to moe_evaluation_results.json")

    return avg_test_mae, avg_test_mse
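
# Programmatic use (with the script's default paths; adjust to your setup):
#   mae, mse = evaluate_moe_model("../cleaned", model_dir="models", weights=None)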


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Test Mixture of Experts (MoE) Model for Watermelon Sweetness Prediction")
    parser.add_argument(
        "--data_dir",
        type=str,
        default="../cleaned",
        help="Path to the cleaned dataset directory"
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        default="models",
        help="Directory containing model checkpoints"
    )
    parser.add_argument(
        "--weighting",
        type=str,
        choices=["uniform", "performance"],
        default="uniform",
        help="How to weight the models (uniform or based on performance)"
    )
    args = parser.parse_args()

    # Determine weights based on argument
    weights = None
    if args.weighting == "performance":
        # Weights inversely proportional to the MAE (better models get higher weights)
        # These are the MAE values from the provided results
        mae_values = [0.3635, 0.3765, 0.3959]  # efficientnet_b3+transformer, efficientnet_b0+transformer, resnet50+transformer
        # Convert to weights (inverse of MAE, normalized)
        inverse_mae = [1 / mae for mae in mae_values]
        total = sum(inverse_mae)
        weights = [val / total for val in inverse_mae]
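        # With the MAE values above, the inverse values are roughly 2.751, 2.656 and
        # 2.526, which normalize to weights of about 0.347, 0.335 and 0.318, so the
        # best model (efficientnet_b3 + transformer) receives the largest weight.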
print(f"\033[92mINFO\033[0m: Using performance-based weights: {weights}") | |
else: | |
print(f"\033[92mINFO\033[0m: Using uniform weights") | |
# Evaluate the MoE model | |
evaluate_moe_model(args.data_dir, args.model_dir, weights) |