"""Test Backprop NEAT on 2D classification tasks.

For each toy dataset (XOR, Circle, Spiral) this script evolves a
``BackpropNEAT`` population, trains it with backprop every generation, and
periodically saves the best network's architecture and decision boundary.
"""

import os

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx

from neat.datasets import (generate_xor_data, generate_circle_data,
                           generate_spiral_data, plot_dataset)
from neat.backprop_neat import BackpropNEAT


def _grid_bounds(x, margin: float = 0.1):
    """Return ``(x_min, x_max, y_min, y_max)`` for the decision-boundary grid.

    The window always covers at least [-1, 1] on both axes (the historical
    fixed window). It is widened to the data extent plus ``margin`` times
    the data span when noisy points fall outside that range, so no sample is
    plotted off the shaded region.

    Args:
        x: Input points, assumed to be shaped (n_samples, 2) -- TODO confirm
            against the dataset generators.
        margin: Fractional padding added around the data extent.
    """
    pts = np.asarray(x)
    if pts.size == 0:
        return -1.0, 1.0, -1.0, 1.0
    lo = pts.min(axis=0)
    hi = pts.max(axis=0)
    # Guard against a zero span (all points on a line) so padding is non-zero.
    pad = margin * np.maximum(hi - lo, 1e-6)
    lo = np.minimum(lo - pad, -1.0)
    hi = np.maximum(hi + pad, 1.0)
    return float(lo[0]), float(hi[0]), float(lo[1]), float(hi[1])


def _plot_decision_boundary(network, x, y, title: str, save_path: str,
                            resolution: int = 100):
    """Save a filled-contour plot of ``network``'s first output over the data.

    Args:
        network: Object exposing ``forward(point)``; element ``[0]`` of its
            result is used as the prediction.
        x, y: Dataset passed through to ``plot_dataset`` for the overlay.
        title: Plot title forwarded to ``plot_dataset``.
        save_path: Destination image path.
        resolution: Number of grid samples per axis.
    """
    x_min, x_max, y_min, y_max = _grid_bounds(x)
    xx, yy = jnp.meshgrid(jnp.linspace(x_min, x_max, resolution),
                          jnp.linspace(y_min, y_max, resolution))
    grid_points = jnp.stack([xx.ravel(), yy.ravel()], axis=1)

    # One forward call per grid point, as in the original: forward() is only
    # known (from this file) to accept a single input point.
    predictions = jnp.array([network.forward(p)[0] for p in grid_points])
    predictions = predictions.reshape(xx.shape)

    plt.figure(figsize=(8, 8))
    # extend='both' shades predictions outside [0, 1] instead of leaving
    # those regions blank.
    plt.contourf(xx, yy, predictions, alpha=0.4,
                 levels=jnp.linspace(0, 1, 20), extend='both')
    # NOTE(review): this assumes plot_dataset draws into the current figure;
    # if it opens its own figure the contour above would not be saved.
    plot_dataset(x, y, title)
    plt.savefig(save_path)
    # Close every open figure (including any left by plot_dataset or
    # network.visualize) so figures do not accumulate across generations.
    plt.close('all')


def _save_snapshot(network, x, y, dataset_name: str, gen: int,
                   viz_dir: str):
    """Write the network diagram and decision boundary for generation ``gen``."""
    gen_dir = os.path.join(viz_dir, f'gen_{gen:03d}')
    os.makedirs(gen_dir, exist_ok=True)

    network.visualize(
        save_path=os.path.join(gen_dir, f'{dataset_name}_network.png'))

    _plot_decision_boundary(
        network, x, y, f'{dataset_name} - Generation {gen}',
        os.path.join(gen_dir, f'{dataset_name}_decision_boundary.png'))


def train_and_visualize(neat: BackpropNEAT, x: jnp.ndarray, y: jnp.ndarray,
                        dataset_name: str, viz_dir: str = 'visualizations',
                        n_generations: int = 50, n_epochs: int = 100,
                        viz_every: int = 10):
    """Train network and save visualizations.

    Each generation trains every network with backprop, evaluates fitness,
    reports the best fitness, optionally saves a snapshot, then evolves the
    population.

    Args:
        neat: The population driver to train and evolve.
        x: Input points (presumably (n_samples, 2) -- the boundary plot is 2-D).
        y: Targets passed to training, evaluation and plotting.
        dataset_name: Used in plot titles and output file names.
        viz_dir: Root directory for all saved images.
        n_generations: Number of evolve/train cycles.
        n_epochs: Backprop epochs per generation, passed to ``train_networks``.
        viz_every: Save a snapshot every ``viz_every`` generations. The final
            generation is always saved as well.

    Returns:
        The best network of the last evaluated generation, or ``None`` when
        ``n_generations`` is 0.

    Raises:
        ValueError: If ``viz_every`` is less than 1.
    """
    if viz_every < 1:
        raise ValueError(f'viz_every must be >= 1, got {viz_every}')

    os.makedirs(viz_dir, exist_ok=True)

    # Plot dataset
    plot_dataset(x, y, f'{dataset_name} Dataset')
    plt.savefig(os.path.join(viz_dir, f'{dataset_name}_dataset.png'))
    plt.close()

    best_network = None
    last_gen = n_generations - 1
    for gen in range(n_generations):
        neat.train_networks(x, y, n_epochs=n_epochs)
        neat.evaluate_fitness(x, y)
        best_network = max(neat.population, key=lambda net: net.fitness)

        # Report before evolving: evolve_population() may replace or mutate
        # the population, so log the fitness that was just evaluated.
        print(f'Generation {gen}: Best Fitness = {best_network.fitness:.4f}')

        # Always snapshot the final generation so the trained result is kept.
        if gen % viz_every == 0 or gen == last_gen:
            _save_snapshot(best_network, x, y, dataset_name, gen, viz_dir)

        neat.evolve_population()

    return best_network


def main():
    """Run experiments on different datasets."""
    # Parameters
    n_points = 50  # Points per quadrant/class
    noise_level = 0.1
    population_size = 50

    # Test on different datasets
    datasets = [
        ('XOR', generate_xor_data),
        ('Circle', generate_circle_data),
        ('Spiral', generate_spiral_data),
    ]

    for name, generator in datasets:
        print(f'\nTraining on {name} dataset:')

        # Generate dataset
        x, y = generator(n_points, noise_level)

        # Create and train NEAT
        neat = BackpropNEAT(n_inputs=2, n_outputs=1,
                            population_size=population_size)

        # Train and visualize
        train_and_visualize(neat, x, y, name)


if __name__ == '__main__':
    main()