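"""Backbone evaluation script for watermelon sweetness prediction.

Trains each (image backbone, audio backbone) combination for a single epoch on
the cleaned dataset, then reports validation/test MSE and MAE so the
combinations can be ranked. Results are written incrementally to a JSON file,
and model checkpoints can optionally be saved.
"""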

import os
import torch
import torchaudio
import torchvision
import numpy as np
import time
import json
from torch.utils.data import Dataset, DataLoader
import sys
from tqdm import tqdm

# Add parent directory to path to import the preprocess functions
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from preprocess import process_audio_data, process_image_data

# Print library versions
print(f"\033[92mINFO\033[0m: PyTorch version: {torch.__version__}")
print(f"\033[92mINFO\033[0m: Torchaudio version: {torchaudio.__version__}")
print(f"\033[92mINFO\033[0m: Torchvision version: {torchvision.__version__}")

# Device selection
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"\033[92mINFO\033[0m: Using device: {device}")

# Hyperparameters
batch_size = 16
epochs = 1  # Just one epoch for evaluation
learning_rate = 0.0001


class WatermelonDataset(Dataset):
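    """Dataset of (audio, image, sweetness) samples.

    Expects the layout ``data_dir/<sweetness>/<id>/<id>.wav`` and
    ``data_dir/<sweetness>/<id>/<id>.jpg``, where the sweetness folder name is
    the numeric regression label.
    """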
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.samples = []
        # Walk through the directory structure: <sweetness>/<id>/<id>.{wav,jpg}
        for sweetness_dir in os.listdir(data_dir):
            sweetness_path = os.path.join(data_dir, sweetness_dir)
            # Skip stray files; only numeric sweetness directories hold samples
            if not os.path.isdir(sweetness_path):
                continue
            sweetness = float(sweetness_dir)
            for id_dir in os.listdir(sweetness_path):
                id_path = os.path.join(sweetness_path, id_dir)
                if os.path.isdir(id_path):
                    audio_file = os.path.join(id_path, f"{id_dir}.wav")
                    image_file = os.path.join(id_path, f"{id_dir}.jpg")
                    if os.path.exists(audio_file) and os.path.exists(image_file):
                        self.samples.append((audio_file, image_file, sweetness))
        print(f"\033[92mINFO\033[0m: Loaded {len(self.samples)} samples from {data_dir}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        audio_path, image_path, label = self.samples[idx]
        try:
            # Load and process audio
            waveform, sample_rate = torchaudio.load(audio_path)
            mfcc = process_audio_data(waveform, sample_rate)
            # Load and process image
            image = torchvision.io.read_image(image_path)
            image = image.float()
            processed_image = process_image_data(image)
            return mfcc, processed_image, torch.tensor(label).float()
        except Exception as e:
            print(f"\033[91mERR!\033[0m: Error processing sample {idx}: {e}")
            # Return a fallback sample instead of skipping: for simplicity,
            # reuse the first sample.
            if idx == 0:  # Prevent infinite recursion
                raise e
            return self.__getitem__(0)

# Define available backbone models
IMAGE_BACKBONES = {
    "resnet50": {
        "model": torchvision.models.resnet50,
        "weights": torchvision.models.ResNet50_Weights.DEFAULT,
        "output_dim": lambda model: model.fc.in_features,
    },
    "efficientnet_b0": {
        "model": torchvision.models.efficientnet_b0,
        "weights": torchvision.models.EfficientNet_B0_Weights.DEFAULT,
        "output_dim": lambda model: model.classifier[1].in_features,
    },
    "efficientnet_b3": {
        "model": torchvision.models.efficientnet_b3,
        "weights": torchvision.models.EfficientNet_B3_Weights.DEFAULT,
        "output_dim": lambda model: model.classifier[1].in_features,
    },
}

AUDIO_BACKBONES = {
    "lstm": {
        "model": lambda input_size, hidden_size: torch.nn.LSTM(
            input_size=input_size, hidden_size=hidden_size, num_layers=2, batch_first=True
        ),
        "output_dim": lambda hidden_size: hidden_size,
    },
    "gru": {
        "model": lambda input_size, hidden_size: torch.nn.GRU(
            input_size=input_size, hidden_size=hidden_size, num_layers=2, batch_first=True
        ),
        "output_dim": lambda hidden_size: hidden_size,
    },
    "bidirectional_lstm": {
        "model": lambda input_size, hidden_size: torch.nn.LSTM(
            input_size=input_size, hidden_size=hidden_size, num_layers=2, batch_first=True, bidirectional=True
        ),
        "output_dim": lambda hidden_size: hidden_size * 2,  # * 2 because bidirectional
    },
    "transformer": {
        "model": lambda input_size, hidden_size: torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(
                d_model=input_size, nhead=8, dim_feedforward=hidden_size, batch_first=True
            ),
            num_layers=2,
        ),
        # The TransformerEncoder preserves d_model, so its output dim equals the
        # MFCC input size (376) rather than hidden_size.
        "output_dim": lambda hidden_size: 376,
    },
}

class WatermelonModelModular(torch.nn.Module):
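    """Two-branch regression model with swappable image and audio backbones.

    The audio branch (LSTM/GRU/bidirectional LSTM/Transformer encoder) and the
    image branch (torchvision CNN) are each projected to 128-d features,
    concatenated, and passed through a small MLP that predicts sweetness.
    """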
    def __init__(self, image_backbone_name, audio_backbone_name, audio_hidden_size=128):
        super(WatermelonModelModular, self).__init__()
        # Audio backbone setup
        self.audio_backbone_name = audio_backbone_name
        self.audio_hidden_size = audio_hidden_size
        self.audio_input_size = 376  # From MFCC dimensions
        audio_config = AUDIO_BACKBONES[audio_backbone_name]
        self.audio_backbone = audio_config["model"](self.audio_input_size, self.audio_hidden_size)
        audio_output_dim = audio_config["output_dim"](self.audio_hidden_size)
        self.audio_fc = torch.nn.Linear(audio_output_dim, 128)
        # Image backbone setup
        self.image_backbone_name = image_backbone_name
        image_config = IMAGE_BACKBONES[image_backbone_name]
        self.image_backbone = image_config["model"](weights=image_config["weights"])
        # Replace the final classification layer so the backbone returns features
        if image_backbone_name.startswith("resnet"):
            self.image_output_dim = image_config["output_dim"](self.image_backbone)
            self.image_backbone.fc = torch.nn.Identity()
        elif image_backbone_name.startswith("efficientnet"):
            self.image_output_dim = image_config["output_dim"](self.image_backbone)
            self.image_backbone.classifier = torch.nn.Identity()
        elif image_backbone_name.startswith("convnext"):
            self.image_output_dim = image_config["output_dim"](self.image_backbone)
            self.image_backbone.classifier = torch.nn.Identity()
        elif image_backbone_name.startswith("swin"):
            self.image_output_dim = image_config["output_dim"](self.image_backbone)
            self.image_backbone.head = torch.nn.Identity()
        else:
            raise ValueError(f"Unsupported image backbone: {image_backbone_name}")
        self.image_fc = torch.nn.Linear(self.image_output_dim, 128)
        # Fully connected layers for final prediction
        self.fc1 = torch.nn.Linear(256, 64)
        self.fc2 = torch.nn.Linear(64, 1)
        self.relu = torch.nn.ReLU()

    def forward(self, mfcc, image):
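        """Fuse the MFCC sequence and image features into one sweetness score."""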
        # Audio backbone processing
        if self.audio_backbone_name in ("lstm", "gru", "bidirectional_lstm"):
            audio_output, _ = self.audio_backbone(mfcc)
            audio_output = audio_output[:, -1, :]  # Use the output of the last time step
        elif self.audio_backbone_name == "transformer":
            audio_output = self.audio_backbone(mfcc)
            audio_output = audio_output.mean(dim=1)  # Average pooling over sequence length
        audio_output = self.audio_fc(audio_output)
        # Image backbone processing
        image_output = self.image_backbone(image)
        image_output = self.image_fc(image_output)
        # Concatenate audio and image outputs
        merged = torch.cat((audio_output, image_output), dim=1)
        # Fully connected layers
        output = self.relu(self.fc1(merged))
        output = self.fc2(output)
        return output

def evaluate_model(data_dir, image_backbone, audio_backbone, audio_hidden_size=128, save_model_dir=None):
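    """Train the given backbone combination for one epoch and report MSE/MAE.

    Returns a dict with the backbone names, validation/test MSE and MAE, and
    (if ``save_model_dir`` is given) the path of the saved checkpoint.
    """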
    # Adjust batch size based on model complexity to avoid OOM errors
    adjusted_batch_size = batch_size
    # Models that typically require more memory get smaller batch sizes
    if image_backbone in ["swin_b", "convnext_base"] or audio_backbone in ["transformer", "bidirectional_lstm"]:
        adjusted_batch_size = max(4, batch_size // 2)  # At least batch size of 4, but reduce by half if needed
        print(f"\033[92mINFO\033[0m: Adjusted batch size to {adjusted_batch_size} for larger model")
    # Create dataset
    dataset = WatermelonDataset(data_dir)
    n_samples = len(dataset)
    # Split dataset
    train_size = int(0.7 * n_samples)
    val_size = int(0.2 * n_samples)
    test_size = n_samples - train_size - val_size
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size]
    )
    train_loader = DataLoader(train_dataset, batch_size=adjusted_batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=adjusted_batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=adjusted_batch_size, shuffle=False)
    # Initialize model
    model = WatermelonModelModular(image_backbone, audio_backbone, audio_hidden_size).to(device)
    # Loss function and optimizer
    criterion = torch.nn.MSELoss()
    mae_criterion = torch.nn.L1Loss()  # For MAE evaluation
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    print(f"\033[92mINFO\033[0m: Evaluating model with {image_backbone} (image) and {audio_backbone} (audio)")
    print(f"\033[92mINFO\033[0m: Training samples: {len(train_dataset)}")
    print(f"\033[92mINFO\033[0m: Validation samples: {len(val_dataset)}")
    print(f"\033[92mINFO\033[0m: Test samples: {len(test_dataset)}")
    print(f"\033[92mINFO\033[0m: Batch size: {adjusted_batch_size}")
    # Training loop
    print(f"\033[92mINFO\033[0m: Training for evaluation...")
    model.train()
    running_loss = 0.0
    # Wrap with tqdm for progress visualization
    train_iterator = tqdm(train_loader, desc="Training")
    for i, (mfcc, image, label) in enumerate(train_iterator):
        try:
            mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
            optimizer.zero_grad()
            output = model(mfcc, image)
            label = label.view(-1, 1).float()
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            train_iterator.set_postfix({"Loss": f"{loss.item():.4f}"})
            # Clear memory after each batch
            if device.type == 'cuda':
                del mfcc, image, label, output, loss
                torch.cuda.empty_cache()
        except Exception as e:
            print(f"\033[91mERR!\033[0m: Error in training batch {i}: {e}")
            # Clear memory in case of error
            if device.type == 'cuda':
                torch.cuda.empty_cache()
            continue
    # Validation phase
    print(f"\033[92mINFO\033[0m: Validating...")
    model.eval()
    val_loss = 0.0
    val_mae = 0.0
    val_iterator = tqdm(val_loader, desc="Validation")
    with torch.no_grad():
        for i, (mfcc, image, label) in enumerate(val_iterator):
            try:
                mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
                output = model(mfcc, image)
                label = label.view(-1, 1).float()
                # Calculate MSE loss
                loss = criterion(output, label)
                val_loss += loss.item()
                # Calculate MAE
                mae = mae_criterion(output, label)
                val_mae += mae.item()
                val_iterator.set_postfix({"MSE": f"{loss.item():.4f}", "MAE": f"{mae.item():.4f}"})
                # Clear memory after each batch
                if device.type == 'cuda':
                    del mfcc, image, label, output, loss, mae
                    torch.cuda.empty_cache()
            except Exception as e:
                print(f"\033[91mERR!\033[0m: Error in validation batch {i}: {e}")
                # Clear memory in case of error
                if device.type == 'cuda':
                    torch.cuda.empty_cache()
                continue
    avg_val_loss = val_loss / len(val_loader) if len(val_loader) > 0 else float('inf')
    avg_val_mae = val_mae / len(val_loader) if len(val_loader) > 0 else float('inf')
    # Test phase
    print(f"\033[92mINFO\033[0m: Testing...")
    model.eval()
    test_loss = 0.0
    test_mae = 0.0
    test_iterator = tqdm(test_loader, desc="Testing")
    with torch.no_grad():
        for i, (mfcc, image, label) in enumerate(test_iterator):
            try:
                mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
                output = model(mfcc, image)
                label = label.view(-1, 1).float()
                # Calculate MSE loss
                loss = criterion(output, label)
                test_loss += loss.item()
                # Calculate MAE
                mae = mae_criterion(output, label)
                test_mae += mae.item()
                test_iterator.set_postfix({"MSE": f"{loss.item():.4f}", "MAE": f"{mae.item():.4f}"})
                # Clear memory after each batch
                if device.type == 'cuda':
                    del mfcc, image, label, output, loss, mae
                    torch.cuda.empty_cache()
            except Exception as e:
                print(f"\033[91mERR!\033[0m: Error in test batch {i}: {e}")
                # Clear memory in case of error
                if device.type == 'cuda':
                    torch.cuda.empty_cache()
                continue
    avg_test_loss = test_loss / len(test_loader) if len(test_loader) > 0 else float('inf')
    avg_test_mae = test_mae / len(test_loader) if len(test_loader) > 0 else float('inf')
    results = {
        "image_backbone": image_backbone,
        "audio_backbone": audio_backbone,
        "validation_mse": avg_val_loss,
        "validation_mae": avg_val_mae,
        "test_mse": avg_test_loss,
        "test_mae": avg_test_mae,
    }
    print(f"\033[92mINFO\033[0m: Evaluation Results:")
    print(f"Image Backbone: {image_backbone}")
    print(f"Audio Backbone: {audio_backbone}")
    print(f"Validation MSE: {avg_val_loss:.4f}")
    print(f"Validation MAE: {avg_val_mae:.4f}")
    print(f"Test MSE: {avg_test_loss:.4f}")
    print(f"Test MAE: {avg_test_mae:.4f}")
    # Save model if save_model_dir is provided
    if save_model_dir:
        os.makedirs(save_model_dir, exist_ok=True)
        model_filename = f"{image_backbone}_{audio_backbone}_model.pt"
        model_path = os.path.join(save_model_dir, model_filename)
        torch.save(model.state_dict(), model_path)
        print(f"\033[92mINFO\033[0m: Model saved to {model_path}")
        # Add model path to results
        results["model_path"] = model_path
    # Clean up memory before returning
    if device.type == 'cuda':
        del model, optimizer, criterion, mae_criterion
        torch.cuda.empty_cache()
    return results

def evaluate_all_combinations(data_dir, image_backbones=None, audio_backbones=None, save_model_dir="test_models", results_file="backbone_evaluation_results.json"):
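    """Evaluate every (image, audio) backbone pairing, resuming from previous results.

    Combinations already present in ``results_file`` are skipped; results are
    re-saved after each evaluation and finally sorted by test MAE.
    """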
    if image_backbones is None:
        image_backbones = list(IMAGE_BACKBONES.keys())
    if audio_backbones is None:
        audio_backbones = list(AUDIO_BACKBONES.keys())
    # Create directory for saving models
    if save_model_dir:
        os.makedirs(save_model_dir, exist_ok=True)
    # Load previous results if the file exists
    results = []
    evaluated_combinations = set()
    if os.path.exists(results_file):
        try:
            with open(results_file, 'r') as f:
                results = json.load(f)
            evaluated_combinations = {(r["image_backbone"], r["audio_backbone"]) for r in results}
            print(f"\033[92mINFO\033[0m: Loaded {len(results)} previous results from {results_file}")
        except Exception as e:
            print(f"\033[91mERR!\033[0m: Error loading previous results from {results_file}: {e}")
            results = []
            evaluated_combinations = set()
    else:
        print(f"\033[93mWARN\033[0m: Results file '{results_file}' does not exist. Starting with empty results.")
    # Create combinations to evaluate, skipping any that have already been evaluated
    combinations = [(img, aud) for img in image_backbones for aud in audio_backbones
                    if (img, aud) not in evaluated_combinations]
    if len(combinations) < len(image_backbones) * len(audio_backbones):
        print(f"\033[92mINFO\033[0m: Skipping {len(evaluated_combinations)} already evaluated combinations")
    print(f"\033[92mINFO\033[0m: Will evaluate {len(combinations)} combinations")
    for image_backbone, audio_backbone in combinations:
        print(f"\033[92mINFO\033[0m: Evaluating {image_backbone} + {audio_backbone}")
        try:
            # Clean GPU memory before each model evaluation
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                print(f"\033[92mINFO\033[0m: CUDA memory cleared before evaluation")
                # Print memory usage for debugging
                print(f"\033[92mINFO\033[0m: CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
                print(f"\033[92mINFO\033[0m: CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
            result = evaluate_model(data_dir, image_backbone, audio_backbone, save_model_dir=save_model_dir)
            results.append(result)
            # Save results after each evaluation
            save_results(results, results_file)
            print(f"\033[92mINFO\033[0m: Updated results saved to {results_file}")
            # Force garbage collection to free memory
            import gc
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                print(f"\033[92mINFO\033[0m: CUDA memory cleared after evaluation")
                # Print memory usage for debugging
                print(f"\033[92mINFO\033[0m: CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
                print(f"\033[92mINFO\033[0m: CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
        except Exception as e:
            print(f"\033[91mERR!\033[0m: Error evaluating {image_backbone} + {audio_backbone}: {e}")
            print(f"\033[91mERR!\033[0m: To continue from this point, use --start_from={image_backbone}:{audio_backbone}")
            # Force garbage collection to free memory even if there's an error
            import gc
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                print(f"\033[92mINFO\033[0m: CUDA memory cleared after error")
            continue
    # Sort results by test MAE (ascending)
    results.sort(key=lambda x: x["test_mae"])
    # Save final sorted results
    save_results(results, results_file)
    print("\n\033[92mINFO\033[0m: === FINAL RESULTS (Sorted by Test MAE) ===")
    print(f"{'Image Backbone':<20} {'Audio Backbone':<20} {'Val MAE':<10} {'Test MAE':<10}")
    print("=" * 60)
    for result in results:
        print(f"{result['image_backbone']:<20} {result['audio_backbone']:<20} {result['validation_mae']:<10.4f} {result['test_mae']:<10.4f}")
    return results

def save_results(results, filename="backbone_evaluation_results.json"):
    """Save evaluation results to a JSON file."""
    with open(filename, 'w') as f:
        json.dump(results, f, indent=4)
    print(f"\033[92mINFO\033[0m: Results saved to {filename}")

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Evaluate Different Backbones for Watermelon Sweetness Prediction")
    parser.add_argument(
        "--data_dir",
        type=str,
        default="../cleaned",
        help="Path to the cleaned dataset directory"
    )
    parser.add_argument(
        "--image_backbone",
        type=str,
        default=None,
        help="Specific image backbone to evaluate (leave empty to evaluate all available)"
    )
    parser.add_argument(
        "--audio_backbone",
        type=str,
        default=None,
        help="Specific audio backbone to evaluate (leave empty to evaluate all available)"
    )
    parser.add_argument(
        "--evaluate_all",
        action="store_true",
        help="Evaluate all combinations of backbones"
    )
    parser.add_argument(
        "--start_from",
        type=str,
        default=None,
        help="Start evaluation from a specific combination, format: 'image_backbone:audio_backbone'"
    )
    parser.add_argument(
        "--prioritize_efficient",
        action="store_true",
        help="Prioritize more efficient models first to avoid memory issues"
    )
    parser.add_argument(
        "--results_file",
        type=str,
        default="backbone_evaluation_results.json",
        help="File to save the evaluation results"
    )
    parser.add_argument(
        "--load_previous_results",
        action="store_true",
        help="Load previous results from results_file if it exists"
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        default="test_models",
        help="Directory to save model checkpoints"
    )
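
    # Example invocations (the script filename here is illustrative):
    #   python evaluate_backbones.py --evaluate_all --data_dir ../cleaned
    #   python evaluate_backbones.py --image_backbone resnet50 --audio_backbone lstm --model_dir test_models
    #   python evaluate_backbones.py --prioritize_efficient --load_previous_results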
    args = parser.parse_args()

    # Create model directory if it doesn't exist
    if args.model_dir:
        os.makedirs(args.model_dir, exist_ok=True)
    print(f"\033[92mINFO\033[0m: === Available Image Backbones ===")
    for name in IMAGE_BACKBONES.keys():
        print(f"- {name}")
    print(f"\033[92mINFO\033[0m: === Available Audio Backbones ===")
    for name in AUDIO_BACKBONES.keys():
        print(f"- {name}")
    if args.evaluate_all:
        evaluate_all_combinations(args.data_dir, results_file=args.results_file, save_model_dir=args.model_dir)
    elif args.image_backbone and args.audio_backbone:
        result = evaluate_model(args.data_dir, args.image_backbone, args.audio_backbone, save_model_dir=args.model_dir)
        save_results([result], args.results_file)
    else:
        # Define a default set of backbones to evaluate if not specified
        if args.prioritize_efficient:
            # Start with less memory-intensive models
            image_backbones = ["resnet50", "efficientnet_b0", "resnet101", "efficientnet_b3", "convnext_base", "swin_b"]
            audio_backbones = ["lstm", "gru", "bidirectional_lstm", "transformer"]
        else:
            # Default selection focusing on better performance models
            image_backbones = ["resnet101", "efficientnet_b3", "swin_b"]
            audio_backbones = ["lstm", "bidirectional_lstm", "transformer"]
        # Keep only backbones registered above; names such as resnet101,
        # convnext_base, or swin_b would otherwise fail with a KeyError.
        image_backbones = [b for b in image_backbones if b in IMAGE_BACKBONES]
        audio_backbones = [b for b in audio_backbones if b in AUDIO_BACKBONES]
        # Create all combinations
        combinations = [(img, aud) for img in image_backbones for aud in audio_backbones]
        # Load previous results if requested and file exists
        previous_results = []
        previous_combinations = set()
        if args.load_previous_results:
            try:
                if os.path.exists(args.results_file):
                    with open(args.results_file, 'r') as f:
                        previous_results = json.load(f)
                    previous_combinations = {(r["image_backbone"], r["audio_backbone"]) for r in previous_results}
                    print(f"\033[92mINFO\033[0m: Loaded {len(previous_results)} previous results")
                else:
                    print(f"\033[93mWARN\033[0m: Results file '{args.results_file}' does not exist. Starting with empty results.")
            except Exception as e:
                print(f"\033[91mERR!\033[0m: Error loading previous results: {e}")
                previous_results = []
                previous_combinations = set()
        # If starting from a specific point
        if args.start_from:
            try:
                start_img, start_aud = args.start_from.split(':')
                start_idx = combinations.index((start_img, start_aud))
                combinations = combinations[start_idx:]
                print(f"\033[92mINFO\033[0m: Starting from combination: {start_img} (image) + {start_aud} (audio)")
            except (ValueError, IndexError):
                print(f"\033[91mERR!\033[0m: Invalid start_from format or combination not found. Format should be 'image_backbone:audio_backbone'")
                print(f"\033[91mERR!\033[0m: Continuing with all combinations.")
        # Skip combinations that have already been evaluated
        if previous_combinations:
            original_count = len(combinations)
            combinations = [(img, aud) for img, aud in combinations if (img, aud) not in previous_combinations]
            print(f"\033[92mINFO\033[0m: Skipping {original_count - len(combinations)} already evaluated combinations")
        # Evaluate each combination
        results = previous_results.copy()
        for img_backbone, audio_backbone in combinations:
            print(f"\033[92mINFO\033[0m: Evaluating {img_backbone} + {audio_backbone}")
            try:
                # Clean GPU memory before each model evaluation
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    print(f"\033[92mINFO\033[0m: CUDA memory cleared before evaluation")
                    print(f"\033[92mINFO\033[0m: CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
                    print(f"\033[92mINFO\033[0m: CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
                result = evaluate_model(args.data_dir, img_backbone, audio_backbone, save_model_dir=args.model_dir)
                results.append(result)
                # Save results after each evaluation
                save_results(results, args.results_file)
                # Force garbage collection to free memory
                import gc
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    print(f"\033[92mINFO\033[0m: CUDA memory cleared after evaluation")
                    print(f"\033[92mINFO\033[0m: CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
                    print(f"\033[92mINFO\033[0m: CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
            except Exception as e:
                print(f"\033[91mERR!\033[0m: Error evaluating {img_backbone} + {audio_backbone}: {e}")
                print(f"\033[91mERR!\033[0m: To continue from this point later, use --start_from={img_backbone}:{audio_backbone}")
                # Force garbage collection to free memory even if there's an error
                import gc
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    print(f"\033[92mINFO\033[0m: CUDA memory cleared after error")
                continue
        # Sort results by test MAE (ascending)
        results.sort(key=lambda x: x["test_mae"])
        # Save final sorted results
        save_results(results, args.results_file)
        print("\n\033[92mINFO\033[0m: === FINAL RESULTS (Sorted by Test MAE) ===")
        print(f"{'Image Backbone':<20} {'Audio Backbone':<20} {'Val MAE':<10} {'Test MAE':<10}")
        print("=" * 60)
        for result in results:
            print(f"{result['image_backbone']:<20} {result['audio_backbone']:<20} {result['validation_mae']:<10.4f} {result['test_mae']:<10.4f}")