File size: 3,388 Bytes
3604754
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""Test Backprop NEAT on 2D classification tasks."""

import os
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from neat.datasets import (generate_xor_data, generate_circle_data, 
                         generate_spiral_data, plot_dataset)
from neat.backprop_neat import BackpropNEAT

def _axis_range(values: np.ndarray, margin: float) -> tuple:
    """Return the (low, high) plotting range for one input axis.

    Keeps the fixed [-1, 1] range used for the decision-boundary grid and only
    widens a side (padded by ``margin``) when some value lies outside it, so
    points beyond the unit square are not cut off the plot.
    """
    if values.size == 0:
        return -1.0, 1.0
    low, high = float(values.min()), float(values.max())
    return (-1.0 if low >= -1.0 else low - margin,
            1.0 if high <= 1.0 else high + margin)


def _save_decision_boundary(network, x, y, title: str, save_path: str,
                            resolution: int = 100, margin: float = 0.1):
    """Contour-plot ``network``'s first output over a grid and save it.

    Args:
        network: Object exposing ``forward(point)``; element ``[0]`` of its
            result is plotted.
        x: Input points; indexed as ``x[:, 0]`` / ``x[:, 1]``, so assumed to
            have shape (n_samples, 2).
        y: Labels, forwarded to ``plot_dataset``.
        title: Plot title passed to ``plot_dataset``.
        save_path: File the figure is written to.
        resolution: Number of grid samples along each axis.
        margin: Padding added to a side of the grid that is widened beyond
            [-1, 1] to fit the data.
    """
    x_np = np.asarray(x)
    x_lo, x_hi = _axis_range(x_np[:, 0], margin)
    y_lo, y_hi = _axis_range(x_np[:, 1], margin)

    plt.figure(figsize=(8, 8))

    xx, yy = jnp.meshgrid(jnp.linspace(x_lo, x_hi, resolution),
                          jnp.linspace(y_lo, y_hi, resolution))
    grid_points = jnp.stack([xx.ravel(), yy.ravel()], axis=1)

    # ``forward`` is called one point at a time; whether it accepts a batch
    # cannot be confirmed from this file.
    predictions = jnp.array([network.forward(p)[0] for p in grid_points])
    predictions = predictions.reshape(xx.shape)

    # extend='both' colours values outside [0, 1] instead of leaving them
    # blank, in case the network output is not squashed into that range.
    plt.contourf(xx, yy, predictions, alpha=0.4,
                 levels=jnp.linspace(0, 1, 20), extend='both')
    # NOTE(review): if plot_dataset opens its own figure, the contour above
    # would not be in the saved image — confirm plot_dataset draws on the
    # current figure.
    plot_dataset(x, y, title)
    plt.savefig(save_path)
    plt.close()


def train_and_visualize(neat: BackpropNEAT, x: jnp.ndarray, y: jnp.ndarray,
                        dataset_name: str, viz_dir: str = 'visualizations',
                        n_generations: int = 50, n_epochs: int = 100,
                        viz_every: int = 10):
    """Train a Backprop NEAT population on one dataset and save visualizations.

    Each generation trains every network with backprop, evaluates fitness,
    picks the fittest network, optionally saves its architecture and decision
    boundary, evolves the population, and prints the best fitness.

    Args:
        neat: Population manager providing ``train_networks``,
            ``evaluate_fitness``, ``evolve_population`` and ``population``.
        x: Input points; the decision-boundary plot indexes ``x[:, 0]`` and
            ``x[:, 1]``, so this assumes shape (n_samples, 2).
        y: Labels used for training, fitness evaluation and plotting.
        dataset_name: Used in plot titles and output file names.
        viz_dir: Directory for the dataset plot and per-generation
            ``gen_XXX`` subdirectories (created if missing).
        n_generations: Number of train/evaluate/evolve cycles.
        n_epochs: Backprop epochs per generation, forwarded to
            ``neat.train_networks``.
        viz_every: Save plots every ``viz_every`` generations; the final
            generation is always saved too.

    Returns:
        The fittest network of the last generation (selected before the final
        ``evolve_population`` call), or ``None`` if ``n_generations`` is 0.

    Raises:
        ValueError: If ``viz_every`` is not positive.
    """
    if viz_every <= 0:
        raise ValueError(f'viz_every must be positive, got {viz_every}')

    os.makedirs(viz_dir, exist_ok=True)

    # Plot the raw dataset once, before any training.
    plot_dataset(x, y, f'{dataset_name} Dataset')
    plt.savefig(os.path.join(viz_dir, f'{dataset_name}_dataset.png'))
    plt.close()

    best_network = None
    for gen in range(n_generations):
        neat.train_networks(x, y, n_epochs=n_epochs)
        neat.evaluate_fitness(x, y)

        best_network = max(neat.population, key=lambda n: n.fitness)
        # Read the fitness now: the report is printed after evolve_population(),
        # which may mutate network objects in place.
        best_fitness = best_network.fitness

        if gen % viz_every == 0 or gen == n_generations - 1:
            gen_dir = os.path.join(viz_dir, f'gen_{gen:03d}')
            os.makedirs(gen_dir, exist_ok=True)

            best_network.visualize(
                save_path=os.path.join(gen_dir, f'{dataset_name}_network.png'))
            _save_decision_boundary(
                best_network, x, y, f'{dataset_name} - Generation {gen}',
                os.path.join(gen_dir, f'{dataset_name}_decision_boundary.png'))

        neat.evolve_population()

        print(f'Generation {gen}: Best Fitness = {best_fitness:.4f}')

    return best_network

def main():
    """Train and visualize a fresh Backprop NEAT population on each 2D dataset.

    For each (name, generator) pair, generate the data, build a new
    2-input / 1-output ``BackpropNEAT`` population and hand both to
    ``train_and_visualize``.
    """
    points_per_class = 50  # points per quadrant/class, passed to each generator
    noise = 0.1
    pop_size = 50

    experiments = (
        ('XOR', generate_xor_data),
        ('Circle', generate_circle_data),
        ('Spiral', generate_spiral_data),
    )

    for dataset_name, make_data in experiments:
        print(f'\nTraining on {dataset_name} dataset:')

        inputs, labels = make_data(points_per_class, noise)
        model = BackpropNEAT(n_inputs=2, n_outputs=1,
                             population_size=pop_size)

        train_and_visualize(model, inputs, labels, dataset_name)

# Script entry point: run every dataset experiment via main() when this file
# is executed directly, but not when it is imported as a module.
if __name__ == '__main__':
    main()