"""BackpropNEAT implementation."""

import jax
import jax.numpy as jnp
import numpy as np
from typing import List

from .network import Network
from .genome import Genome

|
class BackpropNEAT:
    """Backpropagation-based NEAT implementation."""

    def __init__(self, population_size=5, n_inputs=2, n_outputs=1, n_hidden=64,
                 learning_rate=0.01, beta=0.9):
        """Initialize BackpropNEAT."""
        self.population_size = population_size
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        self.n_hidden = n_hidden
        self.learning_rate = learning_rate
        self.beta = beta

        # Best network found so far; read and updated by evolve().
        self.best_fitness = float('-inf')
        self.best_network = None

        self.population = []
        self.momentum_buffers = []
|
        for _ in range(population_size):
            # Each genome starts as a layered network with sparse skip connections.
            genome = Genome(n_inputs, n_outputs, n_hidden)
            genome.add_layer_connections()
            genome.add_skip_connections(0.3)

            network = Network(genome)
            self.population.append(network)

            # Per-network momentum buffers mirror the parameter pytree:
            # 'weights' is a dict of arrays; 'biases', 'gamma' and 'beta' are arrays.
            momentum = {
                'weights': {k: jnp.zeros_like(w) for k, w in network.params['weights'].items()},
                'biases': jnp.zeros_like(network.params['biases']),
                'gamma': jnp.zeros_like(network.params['gamma']),
                'beta': jnp.zeros_like(network.params['beta'])
            }
            self.momentum_buffers.append(momentum)

        self._train_step = self._make_train_step()
|
        # Bind a per-network training step; the default argument freezes each
        # network's index so its updates use its own momentum buffer.
        for i, network in enumerate(self.population):
            network.population_idx = i
            network._train_step = lambda p, x, y, idx=i: self._train_step(p, x, y, idx)
|
    def forward(self, params, x):
        """Forward pass through a network.

        All population members share the same forward computation, so the
        first network's forward is used for every parameter set.
        """
        return self.population[0].forward(params, x)
|
    def _make_train_step(self):
        """Create the training step function."""
        eps = 1e-7
        min_lr = 1e-6
        max_lr = 1e-2

        def loss_fn(params, x, y):
            """Smoothed binary cross-entropy plus a light L2 penalty.

            Targets are assumed to lie in {-1, +1}, matching the tanh-range
            logits returned by the forward pass.
            """
            logits = self.forward(params, x)

            # Label smoothing: pull the targets slightly toward the center.
            alpha = 0.1
            y_smooth = (1 - alpha) * y + alpha * 0.5

            # Map logits from [-1, 1] to probabilities in (0, 1).
            probs = 0.5 * (logits + 1)
            probs = jnp.clip(probs, eps, 1 - eps)

            # Binary cross-entropy, with targets mapped the same way.
            bce_loss = -jnp.mean(
                0.5 * (1 + y_smooth) * jnp.log(probs) +
                0.5 * (1 - y_smooth) * jnp.log(1 - probs)
            )

            # L2 regularization on connection weights.
            l2_reg = sum(jnp.sum(w ** 2) for w in params['weights'].values())
            return bce_loss + 1e-6 * l2_reg
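
        # Worked example of the smoothing and mapping above (alpha = 0.1):
        #   y = +1 -> y_smooth = 0.9 * (+1) + 0.05 = +0.95 -> target prob = 0.5 * (1 + 0.95) = 0.975
        #   y = -1 -> y_smooth = 0.9 * (-1) + 0.05 = -0.85 -> target prob = 0.5 * (1 - 0.85) = 0.075
        # The mapped targets are slightly asymmetric because the smoothing
        # pulls toward 0.5 in tanh space rather than toward the midpoint 0.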
|
        @jax.jit
        def compute_updates(params, x, y):
            """Compute gradients and loss in one JIT-compiled pass."""
            loss_value, grads = jax.value_and_grad(loss_fn)(params, x, y)
            return grads, loss_value
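
        # Note: jax.jit compiles a separate executable per distinct input
        # shape, so a smaller final mini-batch in _compute_fitness triggers
        # one extra compilation the first time that shape appears.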
|
        def train_step(params, x, y, network_idx):
            """Perform a single training step with momentum."""
            grads, loss_value = compute_updates(params, x, y)
            momentum = self.momentum_buffers[network_idx]

            # Global gradient norm over every parameter group.
            grad_norm = jnp.sqrt(
                sum(jnp.sum(g ** 2) for g in grads['weights'].values()) +
                jnp.sum(grads['biases'] ** 2) +
                jnp.sum(grads['gamma'] ** 2) +
                jnp.sum(grads['beta'] ** 2) +
                eps
            )

            # Adaptive learning rate: shrink the step for large gradients,
            # scale it down gently for small ones, then bound the result.
            if grad_norm > 1.0:
                effective_lr = self.learning_rate / grad_norm
            else:
                effective_lr = self.learning_rate * (1.0 + jnp.log(grad_norm + eps))
            effective_lr = jnp.clip(effective_lr, min_lr, max_lr)
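
            # Worked example of the schedule above (learning_rate = 0.01):
            #   grad_norm = 4.0 -> 0.01 / 4.0          = 0.0025
            #   grad_norm = 0.5 -> 0.01 * (1 + ln 0.5) ~ 0.0031
            #   grad_norm = 0.2 -> 0.01 * (1 + ln 0.2) < 0, which the clip
            #                      floors at min_lr = 1e-6.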
|
            # Momentum on clipped gradients; weight decay is added to the
            # update itself rather than accumulated in the momentum buffer.
            new_weights = {}
            for k in params['weights'].keys():
                grad = grads['weights'][k]
                momentum['weights'][k] = (
                    self.beta * momentum['weights'][k] +
                    (1 - self.beta) * jnp.clip(grad, -1.0, 1.0)
                )
                weight_decay = 0.0001 * params['weights'][k]
                new_weights[k] = params['weights'][k] - effective_lr * (
                    momentum['weights'][k] + weight_decay
                )

            momentum['biases'] = (
                self.beta * momentum['biases'] +
                (1 - self.beta) * jnp.clip(grads['biases'], -1.0, 1.0)
            )
            new_biases = params['biases'] - effective_lr * momentum['biases']
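
            # The momentum buffers are exponential moving averages of clipped
            # gradients; as a sanity check (beta = 0.9), a constant gradient g
            # gives m_t = 0.9 * m_{t-1} + 0.1 * g, which converges to m = g,
            # so the steady-state step approaches effective_lr * g.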
|
            # The normalization parameters (gamma, beta) get a smaller step
            # and a tighter gradient clip; gamma is kept in [0.1, 10.0].
            ln_lr = 0.1 * effective_lr

            momentum['gamma'] = (
                self.beta * momentum['gamma'] +
                (1 - self.beta) * jnp.clip(grads['gamma'], -0.1, 0.1)
            )
            new_gamma = params['gamma'] - ln_lr * momentum['gamma']
            new_gamma = jnp.clip(new_gamma, 0.1, 10.0)

            momentum['beta'] = (
                self.beta * momentum['beta'] +
                (1 - self.beta) * jnp.clip(grads['beta'], -0.1, 0.1)
            )
            new_beta = params['beta'] - ln_lr * momentum['beta']

            return {
                'weights': new_weights,
                'biases': new_biases,
                'gamma': new_gamma,
                'beta': new_beta
            }, loss_value

        return train_step
|
    def _mutate_genome(self, genome: Genome) -> Genome:
        """Mutate a genome by perturbing its weights and biases."""
        new_genome = genome.copy()

        # Each parameter is perturbed independently with probability 0.1.
        for key in list(new_genome.params['weights'].keys()):
            if np.random.random() < 0.1:
                new_genome.params['weights'][key] += np.random.normal(0, 0.2)

        for key in list(new_genome.params['biases'].keys()):
            if np.random.random() < 0.1:
                new_genome.params['biases'][key] += np.random.normal(0, 0.2)

        return new_genome
|
    def _select_parent(self, fitnesses: List[float]) -> int:
        """Select a parent index via tournament selection."""
        tournament_size = 3
        best_idx = np.random.randint(len(fitnesses))
        best_fitness = fitnesses[best_idx]

        # Draw tournament_size candidates uniformly (with replacement) and
        # keep the fittest one.
        for _ in range(tournament_size - 1):
            idx = np.random.randint(len(fitnesses))
            if fitnesses[idx] > best_fitness:
                best_idx = idx
                best_fitness = fitnesses[idx]

        return best_idx
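
    # Tournament selection applies mild selection pressure. For illustration,
    # with population_size = 5 and tournament_size = 3 (sampled with
    # replacement), the fittest network wins a given tournament with
    # probability 1 - (4/5)**3 = 0.488, while the weakest wins only when all
    # three draws pick it: (1/5)**3 = 0.008.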
|
    def _compute_fitness(self, network: Network, x: jnp.ndarray, y: jnp.ndarray,
                         n_epochs: int = 100, batch_size: int = 32) -> float:
        """Train the network briefly and return its best accuracy as fitness."""
        n_samples = x.shape[0]
        best_loss = float('inf')
        best_accuracy = 0.0

        # Accuracy before any gradient steps, for the improvement report below.
        initial_pred = network.predict(x)
        initial_acc = float(jnp.mean(initial_pred == y))

        no_improve = 0
        for epoch in range(n_epochs):
            # Shuffle once per epoch.
            perm = np.random.permutation(n_samples)
            x_shuffled = x[perm]
            y_shuffled = y[perm]

            # Mini-batch SGD over the shuffled data; the final batch may be
            # smaller than batch_size.
            epoch_losses = []
            for i in range(0, n_samples, batch_size):
                batch_x = x_shuffled[i:i + batch_size]
                batch_y = y_shuffled[i:i + batch_size]

                network.params, loss = network._train_step(network.params, batch_x, batch_y)
                epoch_losses.append(float(loss))

            # Track loss plateaus for early stopping.
            avg_loss = float(np.mean(epoch_losses))
            if avg_loss < best_loss:
                best_loss = avg_loss
                no_improve = 0
            else:
                no_improve += 1

            predictions = network.predict(x)
            accuracy = float(jnp.mean(predictions == y))
            best_accuracy = max(best_accuracy, accuracy)

            if epoch % 10 == 0:
                print(f"Epoch {epoch}: Loss = {avg_loss:.4f}, Accuracy = {accuracy:.4f}")

            # Stop early once accuracy is high or the loss has plateaued.
            if accuracy > 0.95 or no_improve >= 10:
                print(f"Early stopping at epoch {epoch}")
                print(f"Final accuracy: {accuracy:.4f}")
                break

        print(f"Network improved from {initial_acc:.4f} to {best_accuracy:.4f}")

        # Fitness is the best training accuracy observed.
        return float(best_accuracy)
|
    def evolve(self, x: jnp.ndarray, y: jnp.ndarray, n_generations: int = 50) -> Network:
        """Evolve the population and return the best network found."""
        for generation in range(n_generations):
            print(f"\nGeneration {generation}")

            # Evaluate every member of the population.
            fitnesses = []
            for network in self.population:
                fitness = self._compute_fitness(network, x, y)
                fitnesses.append(fitness)

                if fitness > self.best_fitness:
                    self.best_fitness = fitness
                    self.best_network = Network(network.genome.copy())
                    print(f"New best fitness: {fitness:.4f}")

            # Elitism: carry the best genome over unchanged.
            new_population = []
            best_idx = int(np.argmax(fitnesses))
            new_population.append(Network(self.population[best_idx].genome.copy()))

            # Fill the remaining slots with mutated tournament winners.
            while len(new_population) < self.population_size:
                parent_idx = self._select_parent(fitnesses)
                parent = self.population[parent_idx].genome

                child_genome = self._mutate_genome(parent)
                new_population.append(Network(child_genome))

            self.population = new_population

            # Fresh Network objects do not carry the training step bound in
            # __init__, so rebind it here (momentum buffers are reused, since
            # mutation preserves parameter shapes).
            for i, network in enumerate(self.population):
                network.population_idx = i
                network._train_step = lambda p, xb, yb, idx=i: self._train_step(p, xb, yb, idx)

        return self.best_network
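

# Minimal usage sketch (illustrative, not part of the library API). It assumes
# targets in {-1, +1} shaped to match network.predict(), consistent with the
# loss function above, and that Genome/Network behave as used in this module.
if __name__ == "__main__":
    # XOR with {-1, +1} labels.
    xor_x = jnp.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
    xor_y = jnp.array([-1.0, 1.0, 1.0, -1.0])

    neat = BackpropNEAT(population_size=5, n_inputs=2, n_outputs=1, n_hidden=16)
    best = neat.evolve(xor_x, xor_y, n_generations=5)
    print(f"Best fitness achieved: {neat.best_fitness:.4f}")
    print(f"Best network predictions: {best.predict(xor_x)}")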
|
|