# neat/backprop_test.py
# Uploaded by eyad-silx with huggingface_hub (commit 3604754, verified; 3.39 kB).
"""Test Backprop NEAT on 2D classification tasks."""
import os
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from neat.datasets import (generate_xor_data, generate_circle_data,
generate_spiral_data, plot_dataset)
from neat.backprop_neat import BackpropNEAT
def _save_generation_plots(network, x, y, dataset_name, viz_dir, gen):
    """Save the architecture and decision boundary of ``network`` for ``gen``.

    Images are written to ``<viz_dir>/gen_<gen:03d>/``.
    """
    gen_dir = os.path.join(viz_dir, f'gen_{gen:03d}')
    os.makedirs(gen_dir, exist_ok=True)

    # Visualize network architecture
    network.visualize(
        save_path=os.path.join(gen_dir, f'{dataset_name}_network.png'))

    # Plot decision boundary
    plt.figure(figsize=(8, 8))
    try:
        # Evaluate the network on a 100x100 grid over [-1, 1]^2.
        xx, yy = jnp.meshgrid(jnp.linspace(-1, 1, 100),
                              jnp.linspace(-1, 1, 100))
        grid_points = jnp.stack([xx.ravel(), yy.ravel()], axis=1)
        # forward() is called one point at a time; only output 0 is plotted.
        predictions = jnp.array([network.forward(p)[0]
                                 for p in grid_points])
        predictions = predictions.reshape(xx.shape)

        plt.contourf(xx, yy, predictions, alpha=0.4,
                     levels=jnp.linspace(0, 1, 20))
        plot_dataset(x, y, f'{dataset_name} - Generation {gen}')
        plt.savefig(os.path.join(gen_dir,
                                 f'{dataset_name}_decision_boundary.png'))
    finally:
        # Close even if forward()/plotting raises, so figures don't pile up.
        plt.close()


def train_and_visualize(neat: BackpropNEAT, x: jnp.ndarray, y: jnp.ndarray,
                        dataset_name: str, viz_dir: str = 'visualizations',
                        n_generations: int = 50, n_epochs: int = 100,
                        viz_every: int = 10):
    """Train network and save visualizations.

    Each generation trains the population with backprop, evaluates fitness,
    selects the fittest network, optionally saves plots for it, and then
    evolves the population.

    Args:
        neat: Population manager; its ``train_networks``,
            ``evaluate_fitness`` and ``evolve_population`` methods are
            called once per generation.
        x: Training inputs (presumably shape ``(n_samples, 2)``, since the
            decision-boundary grid is 2-D; TODO confirm).
        y: Training labels matching ``x``.
        dataset_name: Used in plot titles and output file names.
        viz_dir: Root directory for saved images (created if missing).
        n_generations: Number of train/evaluate/evolve cycles.
        n_epochs: Backprop epochs passed to ``train_networks`` each
            generation.
        viz_every: Save plots every ``viz_every`` generations. The final
            generation is always saved too.

    Returns:
        The fittest network of the last generation (selected before the
        final ``evolve_population`` call), or ``None`` when
        ``n_generations`` is 0.

    Raises:
        ValueError: If ``viz_every`` is less than 1.
    """
    if viz_every < 1:
        raise ValueError(f'viz_every must be >= 1, got {viz_every}')

    os.makedirs(viz_dir, exist_ok=True)

    # Plot dataset
    plot_dataset(x, y, f'{dataset_name} Dataset')
    plt.savefig(os.path.join(viz_dir, f'{dataset_name}_dataset.png'))
    plt.close()

    best_network = None
    for gen in range(n_generations):
        # Train networks with backprop, then score them.
        neat.train_networks(x, y, n_epochs=n_epochs)
        neat.evaluate_fitness(x, y)

        best_network = max(neat.population, key=lambda n: n.fitness)

        # Also snapshot the last generation; otherwise the final trained
        # network is never visualized (e.g. gen 49 with the defaults).
        is_last_gen = gen == n_generations - 1
        if gen % viz_every == 0 or is_last_gen:
            _save_generation_plots(best_network, x, y, dataset_name,
                                   viz_dir, gen)

        # Evolve population
        neat.evolve_population()
        print(f'Generation {gen}: Best Fitness = {best_network.fitness:.4f}')

    return best_network
def main():
    """Run Backprop NEAT on the XOR, Circle and Spiral datasets in turn.

    For each dataset, a fresh 2-input / 1-output ``BackpropNEAT`` population
    is built and handed to ``train_and_visualize``.
    """
    points_per_class = 50  # Points per quadrant/class
    noise = 0.1
    pop_size = 50

    experiments = (
        ('XOR', generate_xor_data),
        ('Circle', generate_circle_data),
        ('Spiral', generate_spiral_data),
    )
    for dataset_name, make_data in experiments:
        print(f'\nTraining on {dataset_name} dataset:')
        inputs, targets = make_data(points_per_class, noise)

        # A new population per dataset, so runs do not share evolved state.
        model = BackpropNEAT(n_inputs=2, n_outputs=1,
                             population_size=pop_size)
        train_and_visualize(model, inputs, targets, dataset_name)
if __name__ == '__main__':
    # Run the experiments only when executed as a script, not on import.
    main()