|
"""Train BackpropNEAT on Spiral dataset.""" |
|
|
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import jax.numpy as jnp |
|
import jax |
|
import os |
|
import json |
|
from datetime import datetime |
|
from sklearn.model_selection import train_test_split |
|
from sklearn.utils import shuffle |
|
|
|
from neat.backprop_neat import BackpropNEAT |
|
from neat.datasets import generate_spiral_dataset |
|
from neat.network import Network |
|
from neat.genome import Genome |
|
|
|
class NetworkLogger:
    """Record per-epoch snapshots of a network and persist them as JSON.

    Every call to :meth:`log_network` appends one snapshot to ``history`` and
    rewrites ``evolution_log.json`` inside ``output_dir`` with the complete
    history, so the file on disk always mirrors the in-memory list.
    """

    def __init__(self, output_dir: str):
        """Set up an empty history bound to ``output_dir``.

        Args:
            output_dir: Directory that will hold ``evolution_log.json``. It is
                created on the first call to :meth:`log_network` if missing.
        """
        self.output_dir = output_dir
        self.log_file = os.path.join(output_dir, "evolution_log.json")
        self.history = []

    def log_network(self, epoch: int, network, loss: float, accuracy: float):
        """Append a snapshot of ``network`` and rewrite the JSON log file.

        Args:
            epoch: Epoch index the snapshot belongs to.
            network: Object exposing ``genome`` with ``n_nodes``,
                ``connections``, ``input_size`` and ``output_size``.
            loss: Loss value for this epoch (converted with ``float``).
            accuracy: Accuracy value for this epoch (converted with ``float``).
        """
        genome = network.genome
        network_state = {
            # Counts and scores are coerced to builtin types so that NumPy/JAX
            # scalars cannot make ``json.dump`` fail.
            'epoch': int(epoch),
            'loss': float(loss),
            'accuracy': float(accuracy),
            'n_nodes': int(genome.n_nodes),
            'n_connections': len(genome.connections),
            'complexity_score': float(self.calculate_complexity(network)),
            'structure': self.get_network_structure(network),
            'timestamp': datetime.now().isoformat(),
        }
        self.history.append(network_state)

        # The log directory may not exist yet; create it instead of failing
        # with FileNotFoundError. An empty ``output_dir`` means the current
        # working directory, which needs no creation.
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)

        with open(self.log_file, 'w') as f:
            json.dump(self.history, f, indent=2)

    def calculate_complexity(self, network):
        """Return the complexity score: half a point per node plus one per connection."""
        n_nodes = network.genome.n_nodes
        n_connections = len(network.genome.connections)
        return n_nodes * 0.5 + n_connections

    def get_network_structure(self, network):
        """Return a JSON-friendly description of the network topology.

        Returns:
            A dict with ``input_size``, ``output_size``, ``hidden_nodes``
            (total nodes minus inputs and outputs) and ``connections``, a list
            of ``{'source', 'target', 'weight'}`` dicts built from the
            ``(src, dst) -> weight`` mapping in ``genome.connections``.
        """
        genome = network.genome
        connections = [
            {'source': int(src), 'target': int(dst), 'weight': float(weight)}
            for (src, dst), weight in genome.connections.items()
        ]
        return {
            'input_size': int(genome.input_size),
            'output_size': int(genome.output_size),
            'hidden_nodes': int(genome.n_nodes - genome.input_size - genome.output_size),
            'connections': connections,
        }

    def plot_evolution(self, save_path: str):
        """Plot accuracy, complexity score and loss per epoch and save the figure.

        Args:
            save_path: File path the figure is written to.
        """
        epochs = [log['epoch'] for log in self.history]
        accuracies = [log['accuracy'] for log in self.history]
        complexities = [log['complexity_score'] for log in self.history]
        losses = [log['loss'] for log in self.history]

        fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(12, 12))

        ax1.plot(epochs, accuracies, 'b-', label='Accuracy')
        ax1.set_ylabel('Accuracy')
        ax1.set_title('Network Evolution')
        ax1.grid(True)
        ax1.legend()

        ax2.plot(epochs, complexities, 'r-', label='Complexity Score')
        ax2.set_ylabel('Complexity Score')
        ax2.grid(True)
        ax2.legend()

        ax3.plot(epochs, losses, 'g-', label='Loss')
        ax3.set_ylabel('Loss')
        ax3.set_xlabel('Epoch')
        ax3.grid(True)
        ax3.legend()

        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        # Close this specific figure so repeated calls do not accumulate figures.
        plt.close(fig)
|
|
|
def visualize_dataset(X, y, network=None, title=None, save_path=None):
    """Scatter-plot a two-class dataset, optionally over a decision surface.

    Args:
        X: Array whose first two columns are used as the plot coordinates.
        y: Labels; samples equal to ``1`` are drawn red, samples equal to
            ``-1`` are drawn blue.
        network: Optional model. When given, ``network.predict`` is evaluated
            on a 100x100 grid of two-column float32 points and drawn as a
            filled contour behind the samples.
        title: Plot title; falls back to ``'Dataset'``.
        save_path: When truthy the figure is saved there, otherwise it is
            shown interactively.
    """
    plt.figure(figsize=(10, 8))

    if network is not None:
        # Grid covering the data range with a 0.5 margin on every side.
        margin = 0.5
        first_col, second_col = X[:, 0], X[:, 1]
        grid_x = np.linspace(first_col.min() - margin, first_col.max() + margin, 100)
        grid_y = np.linspace(second_col.min() - margin, second_col.max() + margin, 100)
        xx, yy = np.meshgrid(grid_x, grid_y)

        mesh_points = jnp.array(np.c_[xx.ravel(), yy.ravel()], dtype=jnp.float32)
        surface = network.predict(mesh_points).reshape(xx.shape)

        plt.contourf(xx, yy, surface, alpha=0.4, cmap='RdYlBu')

    class_styles = ((1, 'red', 'Class 1'), (-1, 'blue', 'Class -1'))
    for class_value, colour, legend_text in class_styles:
        members = y == class_value
        plt.scatter(X[members, 0], X[members, 1], c=colour, label=legend_text)

    plt.grid(True)
    plt.legend()
    plt.title(title or 'Dataset')
    plt.xlabel('X1')
    plt.ylabel('X2')

    if not save_path:
        plt.show()
    else:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Saved plot to {save_path}")
    plt.close()
|
|
|
def _scheduled_learning_rate(epoch, base_lr=0.01, warmup_epochs=5, cycle_length=50):
    """Return the scheduled learning rate for ``epoch``: linear warm-up, then cosine cycles."""
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs

    cycle, t = divmod(epoch - warmup_epochs, cycle_length)
    if t == 0:
        # First epoch of a cycle restarts at a peak decayed by 0.9 per cycle.
        # NOTE(review): the decay only applies to this single epoch; the rest
        # of the cycle follows the undecayed cosine below - confirm intent.
        return base_lr * (0.9 ** cycle)
    return base_lr * 0.5 * (1 + np.cos(t * np.pi / cycle_length))


def _sign_accuracy(predictions, targets):
    """Return the fraction of samples whose prediction and target share a sign side."""
    # Flatten both sides first: comparing (N, 1) predictions with (N,) targets
    # would otherwise broadcast to an (N, N) matrix and report a wrong score.
    predicted_positive = jnp.ravel(jnp.asarray(predictions)) > 0
    target_positive = jnp.ravel(jnp.asarray(targets)) > 0
    return float(jnp.mean(predicted_positive == target_positive))


def train_network(network, X, y, n_epochs=300, batch_size=32, patience=50):
    """Train a single network with mini-batches and early stopping.

    Each epoch shuffles the data, calls ``network._train_step`` once per
    batch, then measures training accuracy as sign agreement between
    ``network.predict(X)`` and ``y``. The parameters of the most accurate
    epoch are restored into ``network.params`` before returning.

    Args:
        network: Object exposing ``params`` (a dict containing a ``'weights'``
            mapping), ``_train_step(params, X_batch, y_batch)`` returning
            ``(new_params, loss)``, and ``predict(X)``.
        X: Training inputs; converted to a float32 JAX array.
        y: Training targets; converted to a float32 JAX array.
        n_epochs: Maximum number of epochs.
        batch_size: Samples per batch. Samples beyond the last full batch are
            skipped in each epoch; a dataset smaller than one batch is
            trained as a single batch.
        patience: Number of consecutive epochs without a new best training
            accuracy after which training stops.

    Returns:
        The best training accuracy observed, as a float.

    Raises:
        ValueError: If the dataset is empty or ``batch_size`` is not positive.
    """
    n_samples = len(X)
    if n_samples == 0:
        raise ValueError("Cannot train on an empty dataset")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    print("Starting network training...")
    print(f"Input shape: {X.shape}, Output shape: {y.shape}")
    print(f"Network params: {network.params['weights'].keys()}")

    # At least one batch, so tiny datasets still train and the average loss
    # below never divides by zero.
    n_batches = max(1, n_samples // batch_size)
    best_accuracy = 0.0
    patience_counter = 0
    best_params = None

    print("Converting to JAX arrays...")
    X = jnp.array(X, dtype=jnp.float32)
    y = jnp.array(y, dtype=jnp.float32)

    base_lr = 0.01
    warmup_epochs = 5

    print(f"\nTraining for {n_epochs} epochs with {n_batches} batches per epoch")
    print(f"Batch size: {batch_size}, Patience: {patience}")

    for epoch in range(n_epochs):
        try:
            # Reshuffle inputs and targets together every epoch.
            perm = np.random.permutation(n_samples)
            X = X[perm]
            y = y[perm]

            # NOTE(review): this value is only reported in the log line; it is
            # not passed to ``_train_step`` - confirm whether it should be.
            lr = _scheduled_learning_rate(epoch, base_lr, warmup_epochs)

            epoch_loss = 0.0
            for i in range(n_batches):
                start_idx = i * batch_size
                end_idx = start_idx + batch_size
                X_batch = X[start_idx:end_idx]
                y_batch = y[start_idx:end_idx]

                try:
                    network.params, loss = network._train_step(
                        network.params,
                        X_batch,
                        y_batch
                    )
                    epoch_loss += loss
                except Exception as e:
                    print(f"Error in batch {i}: {str(e)}")
                    print(f"X_batch shape: {X_batch.shape}, y_batch shape: {y_batch.shape}")
                    raise

            train_accuracy = _sign_accuracy(network.predict(X), y)

            if train_accuracy > best_accuracy:
                best_accuracy = train_accuracy
                best_params = {k: v.copy() for k, v in network.params.items()}
                patience_counter = 0
            else:
                patience_counter += 1

            print(f"Epoch {epoch}: Train Acc = {train_accuracy:.4f}, Loss = {epoch_loss/n_batches:.4f}, LR = {lr:.6f}")

            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

        except Exception as e:
            print(f"Error in epoch {epoch}: {str(e)}")
            raise

    if best_params is not None:
        network.params = best_params
        print(f"\nRestored best parameters with accuracy: {best_accuracy:.4f}")

    return best_accuracy
|
|
|
def plot_decision_boundary(network, X, y, save_path):
    """Save a 2x2 figure: decision surface plus three feature-space scatters.

    Args:
        network: Model whose ``predict`` is evaluated on a 100x100 grid of
            five-column float32 feature rows.
        X: Array with at least five columns. Columns 0 and 1 are used as the
            Cartesian coordinates; columns 2, 3 and 4 are presumably the
            normalised radius, angle and dr/dtheta (judging by the plot
            titles) - confirm against the dataset generator.
        y: Labels; points with label ``1`` are red, all others blue.
        save_path: File path the figure is written to.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 15))

    # Scale factors reused for the mesh features and for the scatter plots.
    col_x, col_y = X[:, 0], X[:, 1]
    x_scale = np.max(np.abs(col_x))
    y_scale = np.max(np.abs(col_y))
    tightness_scale = np.max(np.abs(X[:, 4]))

    # Evaluation grid spanning the data range with a 0.1 margin.
    pad = 0.1
    xx, yy = np.meshgrid(
        np.linspace(col_x.min() - pad, col_x.max() + pad, 100),
        np.linspace(col_y.min() - pad, col_y.max() + pad, 100),
    )

    # Polar quantities derived from the grid coordinates.
    r = np.sqrt(xx**2 + yy**2)
    theta = np.unwrap(np.arctan2(yy, xx))
    # NOTE(review): theta can be 0 on the grid, which would yield inf/nan here.
    dr_dtheta = r / theta

    mesh_features = np.column_stack([
        xx.ravel() / x_scale,
        yy.ravel() / y_scale,
        r.ravel() / np.max(X[:, 2] * x_scale),
        theta.ravel() / (6 * np.pi),
        dr_dtheta.ravel() / tightness_scale,
    ])
    Z = network.predict(jnp.array(mesh_features, dtype=jnp.float32)).reshape(xx.shape)

    # One colour per sample, shared by all four panels.
    point_colours = ['red' if label == 1 else 'blue' for label in y]

    # Sample coordinates rescaled for display.
    angle_pts = X[:, 3] * 6 * np.pi
    radius_pts = X[:, 2] * x_scale
    tightness_pts = X[:, 4] * tightness_scale

    cartesian_ax = axes[0, 0]
    cartesian_ax.contourf(xx, yy, Z, alpha=0.4, cmap='RdYlBu')
    cartesian_ax.scatter(col_x * x_scale, col_y * y_scale, c=point_colours, alpha=0.6)
    cartesian_ax.set_title('Cartesian View')
    cartesian_ax.grid(True)

    scatter_panels = (
        (axes[0, 1], angle_pts, radius_pts, 'Polar View (θ vs r)'),
        (axes[1, 0], angle_pts, tightness_pts, 'Spiral Tightness (dr/dθ vs θ)'),
        (axes[1, 1], tightness_pts, radius_pts, 'Growth Rate (r vs dr/dθ)'),
    )
    for panel, horizontal, vertical, panel_title in scatter_panels:
        panel.scatter(horizontal, vertical, c=point_colours, alpha=0.6)
        panel.set_title(panel_title)
        panel.grid(True)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
|
|
|
def main():
    """Train every network of a BackpropNEAT population on the spiral dataset.

    Generates the dataset, holds out 20% for validation, trains each member
    of the population with :func:`train_network`, keeps the network with the
    highest validation accuracy and saves its decision-boundary plot to
    ``spiral_decision_boundary.png``.
    """
    print("\nTraining on Spiral dataset...")

    X, y = generate_spiral_dataset(n_points=1000, noise=0.1)

    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    n_features = X.shape[1]
    neat = BackpropNEAT(
        n_inputs=n_features,
        n_outputs=1,
        n_hidden=32,
        population_size=5,
        learning_rate=0.01,
        beta=0.9
    )

    n_epochs = 300
    batch_size = 32
    patience = 30

    best_network = None
    best_val_acc = 0.0

    population_size = len(neat.population)
    for index, network in enumerate(neat.population, start=1):
        print(f"\nTraining network {index}/{population_size}...")

        train_accuracy = train_network(
            network,
            X_train,
            y_train,
            n_epochs=n_epochs,
            batch_size=batch_size,
            patience=patience
        )

        val_preds = network.predict(X_val)
        # Flatten both sides before comparing: with a single output the
        # predictions may come back as (N, 1) while the labels are (N,), and
        # the comparison would then broadcast to (N, N) and skew the score.
        # NOTE(review): the actual shape of ``predict`` output is not visible
        # here; flattening is a no-op when the shapes already match.
        val_accuracy = float(
            np.mean((np.ravel(val_preds) > 0) == (np.ravel(y_val) > 0))
        )

        print(f"Network {index} - Train Acc: {train_accuracy:.4f}, Val Acc: {val_accuracy:.4f}")

        if val_accuracy > best_val_acc:
            best_val_acc = val_accuracy
            best_network = network

    if best_network is not None:
        plot_path = "spiral_decision_boundary.png"
        plot_decision_boundary(best_network, X, y, plot_path)
        print(f"\nDecision boundary plot saved to {plot_path}")
|
|
|
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()