"""BackpropNEAT implementation."""
import jax
import jax.numpy as jnp
import numpy as np
from typing import Dict, List, Tuple
from .network import Network
from .genome import Genome
class BackpropNEAT:
    """Backpropagation-based NEAT implementation.

    Holds a population of ``Network`` objects built from ``Genome``
    instances. Each network is trained with a momentum-based gradient step
    (see ``_make_train_step``) and the population is evolved through
    tournament selection and mutation (see ``evolve``).
    """

    def __init__(self, population_size=5, n_inputs=2, n_outputs=1, n_hidden=64,
                 learning_rate=0.01, beta=0.9):
        """Initialize BackpropNEAT.

        Args:
            population_size: Number of networks kept in the population.
            n_inputs: Passed to ``Genome`` as its first argument.
            n_outputs: Passed to ``Genome`` as its second argument.
            n_hidden: Passed to ``Genome`` as its third argument.
            learning_rate: Base learning rate used by the train step.
            beta: Momentum coefficient used by the train step.
        """
        self.population_size = population_size
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        self.n_hidden = n_hidden
        self.learning_rate = learning_rate
        self.beta = beta

        # Best candidate seen so far; evolve() compares against these, so
        # they must exist before the first generation is evaluated.
        self.best_fitness = float('-inf')
        self.best_network = None

        # Initialize population
        self.population = []
        self.momentum_buffers = []
        for _ in range(population_size):
            # Create genome with skip connections
            genome = Genome(n_inputs, n_outputs, n_hidden)
            genome.add_layer_connections()    # Add standard layer connections
            genome.add_skip_connections(0.3)  # Add skip connections with 30% probability
            # Create network from genome
            self.population.append(Network(genome))

        # Create train step function
        self._train_step = self._make_train_step()
        # Create momentum buffers and bind the train step to each network
        self._bind_population()

    def _bind_population(self):
        """Rebuild momentum buffers and bind a train step to every network.

        Must be called whenever ``self.population`` is replaced so that
        ``momentum_buffers[i]`` and ``population[i]._train_step`` stay aligned
        with ``population[i]``. Momentum is reset to zero for every network.
        """
        self.momentum_buffers = []
        for i, network in enumerate(self.population):
            params = network.params
            # Zero-initialized buffer mirroring the parameter structure.
            self.momentum_buffers.append({
                'weights': {k: jnp.zeros_like(w) for k, w in params['weights'].items()},
                'biases': jnp.zeros_like(params['biases']),
                'gamma': jnp.zeros_like(params['gamma']),
                'beta': jnp.zeros_like(params['beta']),
            })
            network.population_idx = i
            # ``idx=i`` captures the index at definition time; a plain closure
            # over ``i`` would see only the last loop value.
            network._train_step = (
                lambda p, x, y, idx=i: self._train_step(self, p, x, y, idx)
            )

    def forward(self, params, x):
        """Forward pass through the first network of the population."""
        return self.population[0].forward(params, x)

    def _make_train_step(self):
        """Create training step function.

        Returns:
            A function ``train_step(self, params, x, y, network_idx)`` that
            returns ``(new_params, loss_value)``. It is stored as a plain
            instance attribute, so callers pass ``self`` explicitly.
        """
        # Constants for numerical stability
        eps = 1e-7
        min_lr = 1e-6
        max_lr = 1e-2

        def loss_fn(params, x, y, forward):
            """Compute loss for parameters using the given forward function."""
            logits = forward(params, x)
            # Binary cross entropy loss with label smoothing
            alpha = 0.1  # Label smoothing factor
            y_smooth = (1 - alpha) * y + alpha * 0.5
            # Map outputs to probabilities.
            # NOTE(review): assumes the network output lies in [-1, 1] - confirm
            # against Network.forward.
            probs = 0.5 * (logits + 1)
            probs = jnp.clip(probs, eps, 1 - eps)
            # NOTE(review): the 0.5 * (1 +/- y_smooth) weighting looks like it
            # expects labels in [-1, 1] - confirm the label convention.
            bce_loss = -jnp.mean(
                0.5 * (1 + y_smooth) * jnp.log(probs) +
                0.5 * (1 - y_smooth) * jnp.log(1 - probs)
            )
            # L2 regularization with very small weight
            l2_reg = sum(jnp.sum(w ** 2) for w in params['weights'].values())
            return bce_loss + 0.000001 * l2_reg

        def compute_updates(params, x, y, forward):
            """Compute gradients and loss with respect to ``params``."""
            loss_value, grads = jax.value_and_grad(
                lambda p: loss_fn(p, x, y, forward)
            )(params)
            return grads, loss_value

        def train_step(self, params, x, y, network_idx):
            """Perform single training step with momentum.

            Args:
                params: Parameter dict with 'weights', 'biases', 'gamma', 'beta'.
                x: Batch inputs.
                y: Batch targets.
                network_idx: Index of the network in ``self.population`` (and
                    of its buffer in ``self.momentum_buffers``).

            Returns:
                Tuple ``(new_params, loss_value)``.
            """
            # Differentiate through the network being trained rather than
            # always through population[0].
            forward = self.population[network_idx].forward
            grads, loss_value = compute_updates(params, x, y, forward)

            # Momentum buffer for this network (updated in place below)
            momentum = self.momentum_buffers[network_idx]

            # Gradient norm for adaptive learning rate
            grad_norm = jnp.sqrt(
                sum(jnp.sum(g ** 2) for g in grads['weights'].values()) +
                jnp.sum(grads['biases'] ** 2) +
                jnp.sum(grads['gamma'] ** 2) +
                jnp.sum(grads['beta'] ** 2) +
                eps  # Add eps for numerical stability
            )

            # Adaptive learning rate: scale down for large gradients, otherwise
            # scale by (1 + log(norm)). Both branches are finite because
            # grad_norm >= sqrt(eps) > 0.
            effective_lr = jnp.where(
                grad_norm > 1.0,
                self.learning_rate / grad_norm,
                self.learning_rate * (1.0 + jnp.log(grad_norm + eps)),
            )
            # Clip learning rate to reasonable range
            effective_lr = jnp.clip(effective_lr, min_lr, max_lr)

            # Update weights with clipped-gradient momentum and weight decay
            new_weights = {}
            for k, weight in params['weights'].items():
                velocity = (
                    self.beta * momentum['weights'][k] +
                    (1 - self.beta) * jnp.clip(grads['weights'][k], -1.0, 1.0)
                )
                momentum['weights'][k] = velocity
                weight_decay = 0.0001 * weight
                new_weights[k] = weight - effective_lr * (velocity + weight_decay)

            # Update biases momentum
            momentum['biases'] = (
                self.beta * momentum['biases'] +
                (1 - self.beta) * jnp.clip(grads['biases'], -1.0, 1.0)
            )
            new_biases = params['biases'] - effective_lr * momentum['biases']

            # Update layer norm parameters with smaller learning rate
            ln_lr = 0.1 * effective_lr  # Slower updates for stability

            # Gamma (scale)
            momentum['gamma'] = (
                self.beta * momentum['gamma'] +
                (1 - self.beta) * jnp.clip(grads['gamma'], -0.1, 0.1)
            )
            new_gamma = params['gamma'] - ln_lr * momentum['gamma']
            new_gamma = jnp.clip(new_gamma, 0.1, 10.0)  # Prevent collapse

            # Beta (shift)
            momentum['beta'] = (
                self.beta * momentum['beta'] +
                (1 - self.beta) * jnp.clip(grads['beta'], -0.1, 0.1)
            )
            new_beta = params['beta'] - ln_lr * momentum['beta']

            return {
                'weights': new_weights,
                'biases': new_biases,
                'gamma': new_gamma,
                'beta': new_beta
            }, loss_value

        return train_step

    def _mutate_genome(self, genome: "Genome") -> "Genome":
        """Return a mutated copy of ``genome``.

        Each entry of the copy's ``params['weights']`` and ``params['biases']``
        is perturbed with probability 0.1 by Gaussian noise (std 0.2). The
        input genome is not modified as long as ``genome.copy()`` does not
        share its parameter containers.
        """
        new_genome = genome.copy()
        # Mutate weights and biases
        for group in ('weights', 'biases'):
            for key in list(new_genome.params[group].keys()):
                if np.random.random() < 0.1:
                    new_genome.params[group][key] += np.random.normal(0, 0.2)
        return new_genome

    def _select_parent(self, fitnesses: List[float]) -> int:
        """Select parent using tournament selection.

        Draws three indices uniformly at random (with replacement) and
        returns the one with the highest fitness; ties keep the earlier draw.

        Raises:
            ValueError: If ``fitnesses`` is empty (from ``np.random.randint``).
        """
        tournament_size = 3
        best_idx = np.random.randint(len(fitnesses))
        best_fitness = fitnesses[best_idx]
        for _ in range(tournament_size - 1):
            idx = np.random.randint(len(fitnesses))
            if fitnesses[idx] > best_fitness:
                best_idx = idx
                best_fitness = fitnesses[idx]
        return best_idx

    def _compute_fitness(self, network: "Network", x: jnp.ndarray, y: jnp.ndarray,
                         n_epochs: int = 100, batch_size: int = 32) -> float:
        """Train ``network`` on ``(x, y)`` and return its best accuracy.

        Args:
            network: Object exposing ``params``, ``predict(x)`` and a bound
                ``_train_step(params, batch_x, batch_y)``.
            x: Training inputs, indexed along the first axis.
            y: Training targets, compared to ``network.predict(x)`` with ``==``.
            n_epochs: Maximum number of epochs.
            batch_size: Mini-batch size.

        Returns:
            The highest accuracy observed after any epoch.
        """
        n_samples = x.shape[0]
        best_loss = float('inf')
        best_accuracy = 0.0

        # Initial prediction
        initial_pred = network.predict(x)
        initial_acc = float(jnp.mean((initial_pred == y)))

        # Train network
        no_improve = 0
        for epoch in range(n_epochs):
            # Shuffle data
            perm = np.random.permutation(n_samples)
            x_shuffled = x[perm]
            y_shuffled = y[perm]

            # Train in batches
            epoch_losses = []
            for start in range(0, n_samples, batch_size):
                batch_x = x_shuffled[start:start + batch_size]
                batch_y = y_shuffled[start:start + batch_size]
                # Train step
                network.params, loss = network._train_step(
                    network.params, batch_x, batch_y
                )
                epoch_losses.append(float(loss))

            # Update best loss; count epochs without improvement
            avg_loss = float(np.mean(epoch_losses)) if epoch_losses else float('inf')
            if avg_loss < best_loss:
                best_loss = avg_loss
                no_improve = 0
            else:
                no_improve += 1

            # Compute accuracy
            predictions = network.predict(x)
            accuracy = float(jnp.mean((predictions == y)))
            best_accuracy = max(best_accuracy, accuracy)

            # Print progress every 10 epochs
            if epoch % 10 == 0:
                print(f"Epoch {epoch}: Loss = {avg_loss:.4f}, Accuracy = {accuracy:.4f}")

            # Early stopping if good accuracy or no improvement
            if accuracy > 0.95 or no_improve >= 10:
                print(f"Early stopping at epoch {epoch}")
                print(f"Final accuracy: {accuracy:.4f}")
                break

        # Print improvement
        print(f"Network improved from {initial_acc:.4f} to {best_accuracy:.4f}")

        # Fitness based on accuracy
        fitness = best_accuracy
        return float(fitness)

    def evolve(self, x: jnp.ndarray, y: jnp.ndarray, n_generations: int = 50) -> "Network":
        """Evolve network architectures.

        Each generation trains and scores every network, keeps the fittest
        one unchanged (elitism) and fills the rest of the population with
        mutated children of tournament-selected parents.

        Returns:
            ``self.best_network``: a ``Network`` rebuilt from a copy of the
            genome with the highest fitness seen, or ``None`` if no network
            was ever evaluated.
        """
        for generation in range(n_generations):
            print(f"\nGeneration {generation}")

            # Evaluate current population
            fitnesses = []
            for network in self.population:
                fitness = self._compute_fitness(network, x, y)
                fitnesses.append(fitness)
                # Update best network
                if fitness > self.best_fitness:
                    self.best_fitness = fitness
                    self.best_network = Network(network.genome.copy())
                    print(f"New best fitness: {fitness:.4f}")

            # Keep best network (elitism)
            best_idx = int(np.argmax(fitnesses))
            new_population = [self.population[best_idx]]

            # Create rest of population through selection and mutation
            while len(new_population) < self.population_size:
                parent_idx = self._select_parent(fitnesses)
                parent = self.population[parent_idx].genome
                child_genome = self._mutate_genome(parent)
                new_population.append(Network(child_genome))

            self.population = new_population
            # New networks need their own momentum buffer and train step.
            self._bind_population()

        return self.best_network