|
"""Test Backprop NEAT on 2D classification tasks.""" |
|
|
|
import os |
|
import jax |
|
import jax.numpy as jnp |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import networkx as nx |
|
from neat.datasets import (generate_xor_data, generate_circle_data, |
|
generate_spiral_data, plot_dataset) |
|
from neat.backprop_neat import BackpropNEAT |
|
|
|
def _save_dataset_plot(x, y, dataset_name: str, viz_dir: str) -> None:
    """Save a plot of the raw dataset to ``<viz_dir>/<dataset_name>_dataset.png``."""
    plot_dataset(x, y, f'{dataset_name} Dataset')
    plt.savefig(os.path.join(viz_dir, f'{dataset_name}_dataset.png'))
    plt.close()


def _grid_extent(values, default_lo: float = -1.0, default_hi: float = 1.0):
    """Return (lo, hi) spanning [default_lo, default_hi] and every entry of *values*.

    The default window is only ever widened, never shrunk, so data that lies
    inside [-1, 1] produces exactly the same grid as before.
    """
    if values.size == 0:
        return default_lo, default_hi
    return (min(default_lo, float(jnp.min(values))),
            max(default_hi, float(jnp.max(values))))


def _save_decision_boundary(network, x, y, dataset_name: str, gen: int,
                            gen_dir: str, resolution: int = 100) -> None:
    """Plot *network*'s first output over a 2D grid, overlay the data, and save it."""
    fig = plt.figure(figsize=(8, 8))

    # Grid covers [-1, 1] per axis, widened if any data point lies outside it.
    # Assumes x is (n_samples, 2) -- TODO confirm for all dataset generators.
    x_lo, x_hi = _grid_extent(x[:, 0])
    y_lo, y_hi = _grid_extent(x[:, 1])
    xx, yy = jnp.meshgrid(jnp.linspace(x_lo, x_hi, resolution),
                          jnp.linspace(y_lo, y_hi, resolution))
    grid_points = jnp.stack([xx.ravel(), yy.ravel()], axis=1)

    # One forward() call per grid point: forward is only seen here being
    # called on a single input vector, so batching is not assumed.
    predictions = jnp.array([network.forward(p)[0] for p in grid_points])
    predictions = predictions.reshape(xx.shape)

    plt.contourf(np.asarray(xx), np.asarray(yy), np.asarray(predictions),
                 alpha=0.4, levels=np.linspace(0, 1, 20))
    plot_dataset(x, y, f'{dataset_name} - Generation {gen}')
    plt.savefig(os.path.join(gen_dir,
                             f'{dataset_name}_decision_boundary.png'))

    # plot_dataset may draw on the current figure or open a new one; close the
    # current figure AND ours so neither leaks across generations
    # (closing an already-closed figure is a no-op).
    plt.close()
    plt.close(fig)


def train_and_visualize(neat: BackpropNEAT, x: jnp.ndarray, y: jnp.ndarray,
                        dataset_name: str, viz_dir: str = 'visualizations',
                        *, n_generations: int = 50, n_epochs: int = 100,
                        viz_every: int = 10):
    """Train a BackpropNEAT population and save visualizations.

    Each generation the population is trained with backprop
    (``neat.train_networks``), scored (``neat.evaluate_fitness``), the network
    with the highest ``fitness`` is reported, and the population is evolved
    (``neat.evolve_population``). Every ``viz_every`` generations -- and always
    on the final generation -- the best network's structure and decision
    boundary are written to ``<viz_dir>/gen_<NNN>/``.

    Args:
        neat: Population manager driven through its train/evaluate/evolve API.
        x: Input points; presumably shape (n_samples, 2) since the decision
            boundary is drawn over a 2D grid.
        y: Labels for ``x``.
        dataset_name: Used in plot titles and output file names.
        viz_dir: Output directory, created if it does not exist.
        n_generations: Number of evolutionary generations to run.
        n_epochs: Backprop epochs per generation, passed to ``train_networks``.
        viz_every: Save visualizations every this many generations.

    Raises:
        ValueError: If ``viz_every`` is less than 1.
    """
    if viz_every < 1:
        raise ValueError(f'viz_every must be >= 1, got {viz_every}')

    os.makedirs(viz_dir, exist_ok=True)
    _save_dataset_plot(x, y, dataset_name, viz_dir)

    for gen in range(n_generations):
        neat.train_networks(x, y, n_epochs=n_epochs)
        neat.evaluate_fitness(x, y)

        # Assumes evaluate_fitness sets .fitness on every network -- confirm.
        best_network = max(neat.population, key=lambda n: n.fitness)

        # Always visualize the final generation so the trained result is saved
        # even when n_generations - 1 is not a multiple of viz_every.
        if gen % viz_every == 0 or gen == n_generations - 1:
            gen_dir = os.path.join(viz_dir, f'gen_{gen:03d}')
            os.makedirs(gen_dir, exist_ok=True)
            best_network.visualize(
                save_path=os.path.join(gen_dir, f'{dataset_name}_network.png'))
            _save_decision_boundary(best_network, x, y, dataset_name, gen,
                                    gen_dir)

        # Report before evolving, so the printed value is the one just
        # evaluated even if evolve_population alters networks in place.
        print(f'Generation {gen}: Best Fitness = {best_network.fitness:.4f}')

        neat.evolve_population()
|
|
|
def main():
    """Train a fresh BackpropNEAT population on each toy 2D dataset.

    For every (name, generator) pair: sample a dataset with a fixed size and
    noise level, build a new 2-input / 1-output population, and hand both to
    ``train_and_visualize`` using its default output directory.
    """
    experiments = (
        ('XOR', generate_xor_data),
        ('Circle', generate_circle_data),
        ('Spiral', generate_spiral_data),
    )
    sample_count, noise, pop_size = 50, 0.1, 50

    for dataset_name, make_data in experiments:
        print(f'\nTraining on {dataset_name} dataset:')

        inputs, labels = make_data(sample_count, noise)
        population = BackpropNEAT(n_inputs=2, n_outputs=1,
                                  population_size=pop_size)

        train_and_visualize(population, inputs, labels, dataset_name)


if __name__ == '__main__':
    # Script entry point: run every dataset experiment in turn.
    main()
|
|