"""Train BackpropNEAT on Spiral dataset.""" import numpy as np import matplotlib.pyplot as plt import jax.numpy as jnp import jax import os import json from datetime import datetime from sklearn.model_selection import train_test_split from sklearn.utils import shuffle from neat.backprop_neat import BackpropNEAT from neat.datasets import generate_spiral_dataset from neat.network import Network from neat.genome import Genome class NetworkLogger: """Logger for tracking network evolution.""" def __init__(self, output_dir: str): self.output_dir = output_dir self.log_file = os.path.join(output_dir, "evolution_log.json") self.history = [] def log_network(self, epoch: int, network, loss: float, accuracy: float): """Log network state.""" network_state = { 'epoch': epoch, 'loss': float(loss), 'accuracy': float(accuracy), 'n_nodes': network.genome.n_nodes, 'n_connections': len(network.genome.connections), 'complexity_score': self.calculate_complexity(network), 'structure': self.get_network_structure(network), 'timestamp': datetime.now().isoformat() } self.history.append(network_state) # Save to file with open(self.log_file, 'w') as f: json.dump(self.history, f, indent=2) def calculate_complexity(self, network): """Calculate network complexity score.""" n_nodes = network.genome.n_nodes n_connections = len(network.genome.connections) return n_nodes * 0.5 + n_connections def get_network_structure(self, network): """Get detailed network structure.""" connections = [] for (src, dst), weight in network.genome.connections.items(): connections.append({ 'source': int(src), 'target': int(dst), 'weight': float(weight) }) return { 'input_size': network.genome.input_size, 'output_size': network.genome.output_size, 'hidden_nodes': network.genome.n_nodes - network.genome.input_size - network.genome.output_size, 'connections': connections } def plot_evolution(self, save_path: str): """Plot network evolution metrics.""" epochs = [log['epoch'] for log in self.history] accuracies = [log['accuracy'] for log in self.history] complexities = [log['complexity_score'] for log in self.history] losses = [log['loss'] for log in self.history] fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(12, 12)) # Plot accuracy ax1.plot(epochs, accuracies, 'b-', label='Accuracy') ax1.set_ylabel('Accuracy') ax1.set_title('Network Evolution') ax1.grid(True) ax1.legend() # Plot complexity ax2.plot(epochs, complexities, 'r-', label='Complexity Score') ax2.set_ylabel('Complexity Score') ax2.grid(True) ax2.legend() # Plot loss ax3.plot(epochs, losses, 'g-', label='Loss') ax3.set_ylabel('Loss') ax3.set_xlabel('Epoch') ax3.grid(True) ax3.legend() plt.tight_layout() plt.savefig(save_path, dpi=300, bbox_inches='tight') plt.close() def visualize_dataset(X, y, network=None, title=None, save_path=None): """Visualize dataset with decision boundary.""" plt.figure(figsize=(10, 8)) if network is not None: # Create mesh grid x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5 y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5 xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100), np.linspace(y_min, y_max, 100)) # Make predictions X_mesh = jnp.array(np.c_[xx.ravel(), yy.ravel()], dtype=jnp.float32) Z = network.predict(X_mesh) Z = Z.reshape(xx.shape) # Plot decision boundary plt.contourf(xx, yy, Z, alpha=0.4, cmap='RdYlBu') plt.scatter(X[y == 1, 0], X[y == 1, 1], c='red', label='Class 1') plt.scatter(X[y == -1, 0], X[y == -1, 1], c='blue', label='Class -1') plt.grid(True) plt.legend() plt.title(title or 'Dataset') plt.xlabel('X1') plt.ylabel('X2') if 
def train_network(network, X, y, n_epochs=300, batch_size=32, patience=50):
    """Train a single network with early stopping; returns the best accuracy."""
    print("Starting network training...")
    print(f"Input shape: {X.shape}, Output shape: {y.shape}")
    print(f"Network params: {network.params['weights'].keys()}")

    n_samples = len(X)
    n_batches = n_samples // batch_size
    best_accuracy = 0.0
    patience_counter = 0
    best_params = None

    # Convert to JAX arrays
    print("Converting to JAX arrays...")
    X = jnp.array(X, dtype=jnp.float32)
    y = jnp.array(y, dtype=jnp.float32)

    # Learning rate schedule. Note: the schedule value is only reported in the
    # progress log below; network._train_step is called without a learning
    # rate, so the update uses the rate the network was configured with.
    base_lr = 0.01
    warmup_epochs = 5

    print(f"\nTraining for {n_epochs} epochs with {n_batches} batches per epoch")
    print(f"Batch size: {batch_size}, Patience: {patience}")

    for epoch in range(n_epochs):
        try:
            # Shuffle data
            perm = np.random.permutation(n_samples)
            X = X[perm]
            y = y[perm]

            # Adjust learning rate with warmup and cosine decay
            if epoch < warmup_epochs:
                lr = base_lr * (epoch + 1) / warmup_epochs
            else:
                # Cosine decay with restarts
                cycle_length = 50
                cycle = (epoch - warmup_epochs) // cycle_length
                t = (epoch - warmup_epochs) % cycle_length
                lr = base_lr * 0.5 * (1 + np.cos(t * np.pi / cycle_length))
                # Add a small restart bump at the start of every cycle
                if t == 0:
                    lr = base_lr * (0.9 ** cycle)

            epoch_loss = 0.0

            # Train on mini-batches
            for i in range(n_batches):
                start_idx = i * batch_size
                end_idx = start_idx + batch_size
                X_batch = X[start_idx:end_idx]
                y_batch = y[start_idx:end_idx]

                try:
                    # Update network parameters
                    network.params, loss = network._train_step(
                        network.params, X_batch, y_batch
                    )
                    epoch_loss += loss
                except Exception as e:
                    print(f"Error in batch {i}: {str(e)}")
                    print(f"X_batch shape: {X_batch.shape}, y_batch shape: {y_batch.shape}")
                    raise e

            # Compute training accuracy (labels are +/-1, so compare signs)
            predictions = network.predict(X)
            train_accuracy = np.mean((predictions > 0) == (y > 0))

            # Early stopping check
            if train_accuracy > best_accuracy:
                best_accuracy = train_accuracy
                # Shallow copies suffice here: JAX arrays are immutable
                best_params = {k: v.copy() for k, v in network.params.items()}
                patience_counter = 0
            else:
                patience_counter += 1

            # Print progress every epoch
            print(f"Epoch {epoch}: Train Acc = {train_accuracy:.4f}, "
                  f"Loss = {epoch_loss/n_batches:.4f}, LR = {lr:.6f}")

            # Early stopping
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch}")
                break
        except Exception as e:
            print(f"Error in epoch {epoch}: {str(e)}")
            raise e

    # Restore best parameters
    if best_params is not None:
        network.params = best_params
        print(f"\nRestored best parameters with accuracy: {best_accuracy:.4f}")

    return best_accuracy


def plot_decision_boundary(network, X, y, save_path):
    """Plot the decision boundary in Cartesian and spiral-feature views."""
    fig, axes = plt.subplots(2, 2, figsize=(15, 15))

    # Cartesian view mesh
    x_min, x_max = X[:, 0].min() - 0.1, X[:, 0].max() + 0.1
    y_min, y_max = X[:, 1].min() - 0.1, X[:, 1].max() + 0.1
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
                         np.linspace(y_min, y_max, 100))

    # Derive the full feature set for prediction
    r = np.sqrt(xx**2 + yy**2)
    theta = np.arctan2(yy, xx)
    theta = np.unwrap(theta)
    # Guard against division by zero where theta is (near) zero
    dr_dtheta = r / np.where(np.abs(theta) < 1e-8, 1e-8, theta)

    # Normalize mesh features with the same constants used for the data
    x_norm = xx.ravel() / np.max(np.abs(X[:, 0]))
    y_norm = yy.ravel() / np.max(np.abs(X[:, 1]))
    r_norm = r.ravel() / np.max(X[:, 2] * np.max(np.abs(X[:, 0])))
    theta_norm = theta.ravel() / (6 * np.pi)
    dr_norm = dr_dtheta.ravel() / np.max(np.abs(X[:, 4]))

    # Make predictions
    X_mesh = jnp.array(np.column_stack([
        x_norm, y_norm, r_norm, theta_norm, dr_norm
    ]), dtype=jnp.float32)
    Z = network.predict(X_mesh)
    Z = Z.reshape(xx.shape)

    # Plot Cartesian view
    axes[0, 0].contourf(xx, yy, Z, alpha=0.4, cmap='RdYlBu')
    axes[0, 0].scatter(X[:, 0] * np.max(np.abs(X[:, 0])),
                       X[:, 1] * np.max(np.abs(X[:, 1])),
                       c=['red' if label == 1 else 'blue' for label in y],
                       alpha=0.6)
    axes[0, 0].set_title('Cartesian View')
    axes[0, 0].grid(True)

    # Plot polar view (θ vs r)
    axes[0, 1].scatter(X[:, 3] * 6 * np.pi,                # denormalize theta
                       X[:, 2] * np.max(np.abs(X[:, 0])),  # denormalize radius
                       c=['red' if label == 1 else 'blue' for label in y],
                       alpha=0.6)
    axes[0, 1].set_title('Polar View (θ vs r)')
    axes[0, 1].grid(True)

    # Plot dr/dθ vs θ
    axes[1, 0].scatter(X[:, 3] * 6 * np.pi,                # theta
                       X[:, 4] * np.max(np.abs(X[:, 4])),  # dr/dtheta
                       c=['red' if label == 1 else 'blue' for label in y],
                       alpha=0.6)
    axes[1, 0].set_title('Spiral Tightness (dr/dθ vs θ)')
    axes[1, 0].grid(True)

    # Plot r vs dr/dθ
    axes[1, 1].scatter(X[:, 4] * np.max(np.abs(X[:, 4])),  # dr/dtheta
                       X[:, 2] * np.max(np.abs(X[:, 0])),  # radius
                       c=['red' if label == 1 else 'blue' for label in y],
                       alpha=0.6)
    axes[1, 1].set_title('Growth Rate (r vs dr/dθ)')
    axes[1, 1].grid(True)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
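
# plot_decision_boundary above assumes that generate_spiral_dataset returns
# five columns per point: normalized x, y, radius r, angle theta (spanning up
# to three turns, hence the 6*pi divisor), and spiral tightness dr/dtheta.
# The sketch below reconstructs that encoding from raw coordinates. It is an
# assumption inferred from the normalization constants used above, and
# `spiral_features` is a hypothetical helper, not part of neat.datasets.
def spiral_features(x, y, theta):
    """Build the assumed 5-feature spiral encoding (illustrative sketch).

    `theta` is the already-unwrapped generation angle of each point, since it
    cannot be recovered unambiguously from (x, y) alone.
    """
    r = np.sqrt(x**2 + y**2)
    # For an Archimedean spiral r = a*theta, the tightness dr/dtheta equals r/theta
    dr_dtheta = r / np.where(np.abs(theta) < 1e-8, 1e-8, theta)
    features = np.column_stack([
        x / np.max(np.abs(x)),                   # x, scaled to [-1, 1]
        y / np.max(np.abs(y)),                   # y, scaled to [-1, 1]
        r / np.max(r),                           # radius, scaled to [0, 1]
        theta / (6 * np.pi),                     # angle, assuming <= 3 turns
        dr_dtheta / np.max(np.abs(dr_dtheta)),   # tightness, scaled to [-1, 1]
    ])
    return features.astype(np.float32)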
def main():
    """Main training loop."""
    print("\nTraining on Spiral dataset...")

    # Generate spiral dataset
    X, y = generate_spiral_dataset(n_points=1000, noise=0.1)

    # Split data
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Initialize BackpropNEAT with a smaller network
    n_features = X.shape[1]
    neat = BackpropNEAT(
        n_inputs=n_features,
        n_outputs=1,
        n_hidden=32,          # reduced hidden layer size
        population_size=5,
        learning_rate=0.01,
        beta=0.9
    )

    # Training parameters
    n_epochs = 300
    batch_size = 32
    patience = 30             # reduced patience

    # Train each network in the population
    best_network = None
    best_val_acc = 0.0

    for i, network in enumerate(neat.population):
        print(f"\nTraining network {i+1}/{len(neat.population)}...")

        # Train network
        train_accuracy = train_network(
            network, X_train, y_train,
            n_epochs=n_epochs,
            batch_size=batch_size,
            patience=patience
        )

        # Evaluate on validation set (labels are +/-1, so compare signs)
        val_preds = network.predict(X_val)
        val_accuracy = np.mean((val_preds > 0) == (y_val > 0))
        print(f"Network {i+1} - Train Acc: {train_accuracy:.4f}, Val Acc: {val_accuracy:.4f}")

        # Update best network
        if val_accuracy > best_val_acc:
            best_val_acc = val_accuracy
            best_network = network

    # Plot decision boundary for the best network
    if best_network is not None:
        plot_path = "spiral_decision_boundary.png"
        plot_decision_boundary(best_network, X, y, plot_path)
        print(f"\nDecision boundary plot saved to {plot_path}")


if __name__ == "__main__":
    main()
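
# NetworkLogger is defined above but never wired into main(). The sketch below
# shows one way it could be used, assuming per-epoch access to loss and
# accuracy via the same full-batch _train_step call used in train_network;
# `log_training_run` and `output_dir` are hypothetical names, and this helper
# is not invoked anywhere in the script.
def log_training_run(network, X, y, output_dir="logs", n_epochs=10):
    """Illustrative only: log one network's state after each training epoch."""
    os.makedirs(output_dir, exist_ok=True)
    logger = NetworkLogger(output_dir)
    X_j = jnp.array(X, dtype=jnp.float32)
    y_j = jnp.array(y, dtype=jnp.float32)
    for epoch in range(n_epochs):
        # Full-batch update; mirrors the _train_step signature used above
        network.params, loss = network._train_step(network.params, X_j, y_j)
        accuracy = float(np.mean((network.predict(X_j) > 0) == (y_j > 0)))
        logger.log_network(epoch, network, loss, accuracy)
    # Write the summary plots alongside the JSON history
    logger.plot_evolution(os.path.join(output_dir, "evolution.png"))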