"""Train BackpropNEAT on Spiral dataset."""
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import os
import json
from datetime import datetime

from sklearn.model_selection import train_test_split

from neat.backprop_neat import BackpropNEAT
from neat.datasets import generate_spiral_dataset

class NetworkLogger:
    """Logger for tracking network evolution."""

    def __init__(self, output_dir: str):
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)  # Ensure the log directory exists
        self.log_file = os.path.join(output_dir, "evolution_log.json")
        self.history = []

    def log_network(self, epoch: int, network, loss: float, accuracy: float):
        """Log network state."""
        network_state = {
            'epoch': epoch,
            'loss': float(loss),
            'accuracy': float(accuracy),
            'n_nodes': network.genome.n_nodes,
            'n_connections': len(network.genome.connections),
            'complexity_score': self.calculate_complexity(network),
            'structure': self.get_network_structure(network),
            'timestamp': datetime.now().isoformat()
        }
        self.history.append(network_state)
        # Save to file (rewrites the full history so the log survives interruption)
        with open(self.log_file, 'w') as f:
            json.dump(self.history, f, indent=2)

    def calculate_complexity(self, network):
        """Calculate network complexity score."""
        n_nodes = network.genome.n_nodes
        n_connections = len(network.genome.connections)
        return n_nodes * 0.5 + n_connections
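
    # Worked example of the score above: a genome with 10 nodes and 20
    # connections scores 10 * 0.5 + 20 = 25, so connections weigh twice
    # as much as nodes in the logged complexity.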

    def get_network_structure(self, network):
        """Get detailed network structure."""
        connections = []
        for (src, dst), weight in network.genome.connections.items():
            connections.append({
                'source': int(src),
                'target': int(dst),
                'weight': float(weight)
            })
        return {
            'input_size': network.genome.input_size,
            'output_size': network.genome.output_size,
            'hidden_nodes': (network.genome.n_nodes
                             - network.genome.input_size
                             - network.genome.output_size),
            'connections': connections
        }
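
    # One entry in evolution_log.json then looks like (values illustrative):
    #   {"epoch": 0, "loss": 0.69, "accuracy": 0.5,
    #    "n_nodes": ..., "n_connections": ...,
    #    "complexity_score": ..., "structure": {...}, "timestamp": "..."}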

    def plot_evolution(self, save_path: str):
        """Plot network evolution metrics."""
        epochs = [log['epoch'] for log in self.history]
        accuracies = [log['accuracy'] for log in self.history]
        complexities = [log['complexity_score'] for log in self.history]
        losses = [log['loss'] for log in self.history]
        fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(12, 12))
        # Plot accuracy
        ax1.plot(epochs, accuracies, 'b-', label='Accuracy')
        ax1.set_ylabel('Accuracy')
        ax1.set_title('Network Evolution')
        ax1.grid(True)
        ax1.legend()
        # Plot complexity
        ax2.plot(epochs, complexities, 'r-', label='Complexity Score')
        ax2.set_ylabel('Complexity Score')
        ax2.grid(True)
        ax2.legend()
        # Plot loss
        ax3.plot(epochs, losses, 'g-', label='Loss')
        ax3.set_ylabel('Loss')
        ax3.set_xlabel('Epoch')
        ax3.grid(True)
        ax3.legend()
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
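
# NetworkLogger is not wired into main() below; a typical usage sketch:
#   logger = NetworkLogger("results")
#   logger.log_network(epoch, network, loss, accuracy)  # once per epoch
#   logger.plot_evolution("results/evolution.png")      # after training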

def visualize_dataset(X, y, network=None, title=None, save_path=None):
    """Visualize dataset with decision boundary."""
    plt.figure(figsize=(10, 8))
    if network is not None:
        # Create mesh grid
        x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
        y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
        xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
                             np.linspace(y_min, y_max, 100))
        # Make predictions
        X_mesh = jnp.array(np.c_[xx.ravel(), yy.ravel()], dtype=jnp.float32)
        Z = network.predict(X_mesh)
        Z = Z.reshape(xx.shape)
        # Plot decision boundary
        plt.contourf(xx, yy, Z, alpha=0.4, cmap='RdYlBu')
    plt.scatter(X[y == 1, 0], X[y == 1, 1], c='red', label='Class 1')
    plt.scatter(X[y == -1, 0], X[y == -1, 1], c='blue', label='Class -1')
    plt.grid(True)
    plt.legend()
    plt.title(title or 'Dataset')
    plt.xlabel('X1')
    plt.ylabel('X2')
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Saved plot to {save_path}")
    else:
        plt.show()
    plt.close()
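
# Example call (sketch): visualize_dataset(X_train, y_train,
#     title='Spiral training set', save_path='spiral_train.png').
# Note: the mesh above feeds 2-D (x, y) points to network.predict, so pass
# network= only for a model trained on raw coordinates, not the 5 spiral
# features used elsewhere in this script.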

def train_network(network, X, y, n_epochs=300, batch_size=32, patience=50):
    """Train a single network."""
    print("Starting network training...")
    print(f"Input shape: {X.shape}, Output shape: {y.shape}")
    print(f"Network params: {network.params['weights'].keys()}")
    n_samples = len(X)
    n_batches = n_samples // batch_size
    best_accuracy = 0.0
    patience_counter = 0
    best_params = None
    # Convert to JAX arrays
    print("Converting to JAX arrays...")
    X = jnp.array(X, dtype=jnp.float32)
    y = jnp.array(y, dtype=jnp.float32)
    # Learning rate schedule
    base_lr = 0.01
    warmup_epochs = 5
    print(f"\nTraining for {n_epochs} epochs with {n_batches} batches per epoch")
    print(f"Batch size: {batch_size}, Patience: {patience}")
    for epoch in range(n_epochs):
        try:
            # Shuffle data
            perm = np.random.permutation(n_samples)
            X = X[perm]
            y = y[perm]
            # Adjust learning rate with warmup and cosine decay
            if epoch < warmup_epochs:
                lr = base_lr * (epoch + 1) / warmup_epochs
            else:
                # Cosine decay with restarts
                cycle_length = 50
                cycle = (epoch - warmup_epochs) // cycle_length
                t = (epoch - warmup_epochs) % cycle_length
                lr = base_lr * 0.5 * (1 + np.cos(t * np.pi / cycle_length))
                # Add small restart bump every cycle
                if t == 0:
                    lr = base_lr * (0.9 ** cycle)
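            # Schedule sketch with base_lr=0.01: warmup ramps the rate from
            # 0.002 to 0.01 over epochs 0-4; afterwards each 50-epoch cycle
            # decays it as 0.01 * 0.5 * (1 + cos(pi*t/50)) and the next cycle
            # restarts at 0.01 * 0.9**cycle. Note lr only feeds the progress
            # print below; _train_step is not passed a learning rate here.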
            epoch_loss = 0.0
            # Train on mini-batches
            for i in range(n_batches):
                start_idx = i * batch_size
                end_idx = start_idx + batch_size
                X_batch = X[start_idx:end_idx]
                y_batch = y[start_idx:end_idx]
                try:
                    # Update network parameters
                    network.params, loss = network._train_step(
                        network.params,
                        X_batch,
                        y_batch
                    )
                    epoch_loss += loss
                except Exception as e:
                    print(f"Error in batch {i}: {str(e)}")
                    print(f"X_batch shape: {X_batch.shape}, y_batch shape: {y_batch.shape}")
                    raise
            # Compute training accuracy (labels are +/-1, so sign agreement
            # is classification accuracy). Flatten predictions in case
            # predict() returns shape (n, 1), which would otherwise
            # broadcast against y of shape (n,).
            predictions = network.predict(X)
            train_accuracy = np.mean((np.ravel(predictions) > 0) == (np.ravel(y) > 0))
            # Early stopping check
            if train_accuracy > best_accuracy:
                best_accuracy = train_accuracy
                best_params = {k: v.copy() for k, v in network.params.items()}
                patience_counter = 0
            else:
                patience_counter += 1
            # Print progress every epoch
            print(f"Epoch {epoch}: Train Acc = {train_accuracy:.4f}, Loss = {epoch_loss/n_batches:.4f}, LR = {lr:.6f}")
            # Early stopping
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch}")
                break
        except Exception as e:
            print(f"Error in epoch {epoch}: {str(e)}")
            raise
    # Restore best parameters
    if best_params is not None:
        network.params = best_params
        print(f"\nRestored best parameters with accuracy: {best_accuracy:.4f}")
    return best_accuracy

def plot_decision_boundary(network, X, y, save_path):
    """Plot decision boundary with multiple views."""
    fig, axes = plt.subplots(2, 2, figsize=(15, 15))
    # Cartesian view mesh
    x_min, x_max = X[:, 0].min() - 0.1, X[:, 0].max() + 0.1
    y_min, y_max = X[:, 1].min() - 0.1, X[:, 1].max() + 0.1
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
                         np.linspace(y_min, y_max, 100))
    # Create all features for prediction
    r = np.sqrt(xx**2 + yy**2)
    theta = np.arctan2(yy, xx)
    theta = np.unwrap(theta)
    # Guard against division by zero where theta is (near) zero
    dr_dtheta = r / np.where(np.abs(theta) < 1e-8, 1e-8, theta)
    # Normalize features
    x_norm = xx.ravel() / np.max(np.abs(X[:, 0]))
    y_norm = yy.ravel() / np.max(np.abs(X[:, 1]))
    r_norm = r.ravel() / np.max(X[:, 2] * np.max(np.abs(X[:, 0])))
    theta_norm = theta.ravel() / (6 * np.pi)
    dr_norm = dr_dtheta.ravel() / np.max(np.abs(X[:, 4]))
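
    # These normalizations assume generate_spiral_dataset emits columns
    # (x, y, r, theta, dr/dtheta) with theta spanning up to 6*pi radians
    # (three turns); adjust here if the generator changes.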
    # Make predictions
    X_mesh = jnp.array(np.column_stack([
        x_norm, y_norm, r_norm, theta_norm, dr_norm
    ]), dtype=jnp.float32)
    Z = network.predict(X_mesh)
    Z = Z.reshape(xx.shape)
    # Plot Cartesian view
    axes[0, 0].contourf(xx, yy, Z, alpha=0.4, cmap='RdYlBu')
    axes[0, 0].scatter(X[:, 0] * np.max(np.abs(X[:, 0])),
                       X[:, 1] * np.max(np.abs(X[:, 1])),
                       c=['red' if label == 1 else 'blue' for label in y],
                       alpha=0.6)
    axes[0, 0].set_title('Cartesian View')
    axes[0, 0].grid(True)
    # Plot polar view (θ vs r)
    axes[0, 1].scatter(X[:, 3] * 6 * np.pi,  # Denormalize theta
                       X[:, 2] * np.max(np.abs(X[:, 0])),  # Denormalize radius
                       c=['red' if label == 1 else 'blue' for label in y],
                       alpha=0.6)
    axes[0, 1].set_title('Polar View (θ vs r)')
    axes[0, 1].grid(True)
    # Plot dr/dθ vs θ
    axes[1, 0].scatter(X[:, 3] * 6 * np.pi,  # theta
                       X[:, 4] * np.max(np.abs(X[:, 4])),  # dr/dtheta
                       c=['red' if label == 1 else 'blue' for label in y],
                       alpha=0.6)
    axes[1, 0].set_title('Spiral Tightness (dr/dθ vs θ)')
    axes[1, 0].grid(True)
    # Plot r vs dr/dθ
    axes[1, 1].scatter(X[:, 4] * np.max(np.abs(X[:, 4])),  # dr/dtheta
                       X[:, 2] * np.max(np.abs(X[:, 0])),  # radius
                       c=['red' if label == 1 else 'blue' for label in y],
                       alpha=0.6)
    axes[1, 1].set_title('Growth Rate (r vs dr/dθ)')
    axes[1, 1].grid(True)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def main():
    """Main training loop."""
    print("\nTraining on Spiral dataset...")
    # Generate spiral dataset
    X, y = generate_spiral_dataset(n_points=1000, noise=0.1)
    # Split data
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    # Initialize BackpropNEAT with a smaller network
    n_features = X.shape[1]
    neat = BackpropNEAT(
        n_inputs=n_features,
        n_outputs=1,
        n_hidden=32,  # Reduced hidden layer size
        population_size=5,
        learning_rate=0.01,
        beta=0.9
    )
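    # Assumption: beta is a momentum-style coefficient used by the optimizer
    # inside neat.backprop_neat; see that module for its exact meaning.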
    # Training parameters
    n_epochs = 300
    batch_size = 32
    patience = 30  # Reduced patience
    # Train each network in the population
    best_network = None
    best_val_acc = 0.0
    for i, network in enumerate(neat.population):
        print(f"\nTraining network {i+1}/{len(neat.population)}...")
        # Train network
        train_accuracy = train_network(
            network,
            X_train,
            y_train,
            n_epochs=n_epochs,
            batch_size=batch_size,
            patience=patience
        )
        # Evaluate on validation set (labels are +/-1; compare signs,
        # flattening in case predict() returns shape (n, 1))
        val_preds = network.predict(X_val)
        val_accuracy = np.mean((np.ravel(val_preds) > 0) == (np.ravel(y_val) > 0))
        print(f"Network {i+1} - Train Acc: {train_accuracy:.4f}, Val Acc: {val_accuracy:.4f}")
        # Update best network
        if val_accuracy > best_val_acc:
            best_val_acc = val_accuracy
            best_network = network
    # Plot decision boundary for best network
    if best_network is not None:
        plot_path = "spiral_decision_boundary.png"
        plot_decision_boundary(best_network, X, y, plot_path)
        print(f"\nDecision boundary plot saved to {plot_path}")


if __name__ == "__main__":
    main()
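
# To run (sketch, from the repository root so the `neat` package imports):
#   python backprop_train.py
# This trains 5 candidate networks on the spiral task and saves
# spiral_decision_boundary.png for the best validation performer.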