"""BackpropNEAT implementation."""
import jax
import jax.numpy as jnp
import numpy as np
from typing import List
from .network import Network
from .genome import Genome
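# How this module works, in brief: NEAT-style evolution proposes network
# genomes (Genome), while the weights of each candidate are trained by
# backpropagation (JAX autodiff + SGD with momentum) rather than evolved
# directly. Fitness is the classification accuracy reached after training,
# and the best genome survives across generations via elitism.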
class BackpropNEAT:
"""Backpropagation-based NEAT implementation."""
def __init__(self, population_size=5, n_inputs=2, n_outputs=1, n_hidden=64,
learning_rate=0.01, beta=0.9):
"""Initialize BackpropNEAT."""
self.population_size = population_size
self.n_inputs = n_inputs
self.n_outputs = n_outputs
self.n_hidden = n_hidden
self.learning_rate = learning_rate
        self.beta = beta
        # Best-so-far tracking, read and updated by evolve()
        self.best_fitness = float('-inf')
        self.best_network = None
# Initialize population
self.population = []
self.momentum_buffers = []
for _ in range(population_size):
# Create genome with skip connections
genome = Genome(n_inputs, n_outputs, n_hidden)
genome.add_layer_connections() # Add standard layer connections
genome.add_skip_connections(0.3) # Add skip connections with 30% probability
# Create network from genome
network = Network(genome)
self.population.append(network)
# Initialize momentum buffer for this network
momentum = {
'weights': {k: jnp.zeros_like(w) for k, w in network.params['weights'].items()},
'biases': jnp.zeros_like(network.params['biases']),
'gamma': jnp.zeros_like(network.params['gamma']),
'beta': jnp.zeros_like(network.params['beta'])
}
self.momentum_buffers.append(momentum)
# Create train step function
self._train_step = self._make_train_step()
# Bind train step to each network
for i, network in enumerate(self.population):
network.population_idx = i
            # Bind a per-network step; the default argument pins this network's index
            network._train_step = lambda p, x, y, idx=i: self._train_step(p, x, y, idx)
def forward(self, params, x):
"""Forward pass through network."""
return self.population[0].forward(params, x)
def _make_train_step(self):
"""Create training step function."""
# Constants for numerical stability
eps = 1e-7
min_lr = 1e-6
max_lr = 1e-2
def loss_fn(params, x, y):
"""Compute loss for parameters."""
logits = self.forward(params, x)
            # Binary cross-entropy loss with label smoothing; targets y are in {-1, +1}
            alpha = 0.1  # Label smoothing factor
            # Shrink targets toward 0 in [-1, 1] space, which corresponds to
            # smoothing toward 0.5 in probability space
            y_smooth = (1 - alpha) * y
# Convert logits to probabilities
probs = 0.5 * (logits + 1) # Map from [-1,1] to [0,1]
probs = jnp.clip(probs, eps, 1 - eps)
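            # e.g. a saturated tanh output of +1 would map to probs = 1.0;
            # the clip pulls it back to 1 - eps so log(1 - probs) stays finite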
# Compute loss with label smoothing
bce_loss = -jnp.mean(
0.5 * (1 + y_smooth) * jnp.log(probs) +
0.5 * (1 - y_smooth) * jnp.log(1 - probs)
)
            # L2 regularization with a very small coefficient
            l2_reg = sum(jnp.sum(w ** 2) for w in params['weights'].values())
            return bce_loss + 1e-6 * l2_reg
@jax.jit
def compute_updates(params, x, y):
"""Compute gradients and loss."""
loss_value, grads = jax.value_and_grad(loss_fn)(params, x, y)
return grads, loss_value
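        # Note: jax.jit specializes on input shapes, so a final partial batch
        # in _compute_fitness triggers one extra compilation per new shape.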
        def train_step(params, x, y, network_idx):
            """Perform a single training step with momentum (self is captured from the enclosing method)."""
# Compute gradients
grads, loss_value = compute_updates(params, x, y)
# Get momentum buffer for this network
momentum = self.momentum_buffers[network_idx]
# Gradient norm for adaptive learning rate
grad_norm = jnp.sqrt(
sum(jnp.sum(g ** 2) for g in grads['weights'].values()) +
jnp.sum(grads['biases'] ** 2) +
jnp.sum(grads['gamma'] ** 2) +
jnp.sum(grads['beta'] ** 2) +
eps # Add eps for numerical stability
)
            # Adaptive learning rate: scale down for large gradients, dampen
            # logarithmically for small ones (the clip below keeps it positive)
            if grad_norm > 1.0:
                effective_lr = self.learning_rate / grad_norm
            else:
                effective_lr = self.learning_rate * (1.0 + jnp.log(grad_norm + eps))
# Clip learning rate to reasonable range
effective_lr = jnp.clip(effective_lr, min_lr, max_lr)
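            # Example: grad_norm = 4.0 gives lr / 4; grad_norm = 0.5 gives
            # lr * (1 + ln 0.5) ~= 0.31 * lr; very small norms clip up to min_lr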
# Update weights momentum with adaptive learning rate
new_weights = {}
for k in params['weights'].keys():
grad = grads['weights'][k]
# Update momentum with gradient clipping
momentum['weights'][k] = (
self.beta * momentum['weights'][k] +
(1 - self.beta) * jnp.clip(grad, -1.0, 1.0)
)
# Apply update with weight decay
weight_decay = 0.0001 * params['weights'][k]
new_weights[k] = params['weights'][k] - effective_lr * (
momentum['weights'][k] + weight_decay
)
# Update biases momentum
momentum['biases'] = (
self.beta * momentum['biases'] +
(1 - self.beta) * jnp.clip(grads['biases'], -1.0, 1.0)
)
new_biases = params['biases'] - effective_lr * momentum['biases']
# Update layer norm parameters with smaller learning rate
ln_lr = 0.1 * effective_lr # Slower updates for stability
# Gamma (scale)
momentum['gamma'] = (
self.beta * momentum['gamma'] +
(1 - self.beta) * jnp.clip(grads['gamma'], -0.1, 0.1)
)
new_gamma = params['gamma'] - ln_lr * momentum['gamma']
new_gamma = jnp.clip(new_gamma, 0.1, 10.0) # Prevent collapse
# Beta (shift)
momentum['beta'] = (
self.beta * momentum['beta'] +
(1 - self.beta) * jnp.clip(grads['beta'], -0.1, 0.1)
)
new_beta = params['beta'] - ln_lr * momentum['beta']
return {
'weights': new_weights,
'biases': new_biases,
'gamma': new_gamma,
'beta': new_beta
}, loss_value
return train_step
    def _mutate_genome(self, genome: Genome) -> Genome:
        """Mutate genome parameters: perturb weights and biases, keep the topology."""
new_genome = genome.copy()
# Mutate weights and biases
for key in list(new_genome.params['weights'].keys()):
if np.random.random() < 0.1:
new_genome.params['weights'][key] += np.random.normal(0, 0.2)
for key in list(new_genome.params['biases'].keys()):
if np.random.random() < 0.1:
new_genome.params['biases'][key] += np.random.normal(0, 0.2)
return new_genome
def _select_parent(self, fitnesses: List[float]) -> int:
"""Select parent using tournament selection."""
# Tournament selection
tournament_size = 3
best_idx = np.random.randint(len(fitnesses))
best_fitness = fitnesses[best_idx]
for _ in range(tournament_size - 1):
idx = np.random.randint(len(fitnesses))
if fitnesses[idx] > best_fitness:
best_idx = idx
best_fitness = fitnesses[idx]
return best_idx
def _compute_fitness(self, network: Network, x: jnp.ndarray, y: jnp.ndarray,
n_epochs: int = 100, batch_size: int = 32) -> float:
"""Compute fitness of network."""
n_samples = x.shape[0]
best_loss = float('inf')
best_accuracy = 0.0
# Initial prediction
initial_pred = network.predict(x)
initial_acc = float(jnp.mean((initial_pred == y)))
# Train network
no_improve = 0
for epoch in range(n_epochs):
# Shuffle data
perm = np.random.permutation(n_samples)
x_shuffled = x[perm]
y_shuffled = y[perm]
# Train in batches
epoch_losses = []
            for i in range(0, n_samples, batch_size):
                # Slicing clamps automatically at the end of the array
                batch_x = x_shuffled[i:i + batch_size]
                batch_y = y_shuffled[i:i + batch_size]
# Train step
network.params, loss = network._train_step(network.params, batch_x, batch_y)
epoch_losses.append(float(loss))
# Update best loss
avg_loss = float(np.mean(epoch_losses))
if avg_loss < best_loss:
best_loss = avg_loss
no_improve = 0
else:
no_improve += 1
# Compute accuracy
predictions = network.predict(x)
accuracy = float(jnp.mean((predictions == y)))
best_accuracy = max(best_accuracy, accuracy)
# Print progress every 10 epochs
if epoch % 10 == 0:
print(f"Epoch {epoch}: Loss = {avg_loss:.4f}, Accuracy = {accuracy:.4f}")
# Early stopping if good accuracy or no improvement
if accuracy > 0.95 or no_improve >= 10:
print(f"Early stopping at epoch {epoch}")
print(f"Final accuracy: {accuracy:.4f}")
break
# Print improvement
print(f"Network improved from {initial_acc:.4f} to {best_accuracy:.4f}")
        # Fitness is the best accuracy reached during training
        return float(best_accuracy)
def evolve(self, x: jnp.ndarray, y: jnp.ndarray, n_generations: int = 50) -> Network:
"""Evolve network architectures."""
for generation in range(n_generations):
print(f"\nGeneration {generation}")
# Evaluate current population
fitnesses = []
for network in self.population:
fitness = self._compute_fitness(network, x, y)
fitnesses.append(fitness)
# Update best network
if fitness > self.best_fitness:
self.best_fitness = fitness
self.best_network = Network(network.genome.copy())
print(f"New best fitness: {fitness:.4f}")
# Create new population through selection and mutation
new_population = []
# Keep best network (elitism)
best_idx = np.argmax(fitnesses)
new_population.append(Network(self.population[best_idx].genome.copy()))
# Create rest of population
while len(new_population) < self.population_size:
# Select parent
parent_idx = self._select_parent(fitnesses)
parent = self.population[parent_idx].genome
# Create child through mutation
child_genome = self._mutate_genome(parent)
new_population.append(Network(child_genome))
            self.population = new_population
            # Re-bind train steps and reset momentum: freshly created networks
            # would otherwise lack _train_step on the next generation
            for i, network in enumerate(self.population):
                network.population_idx = i
                network._train_step = lambda p, x, y, idx=i: self._train_step(p, x, y, idx)
                self.momentum_buffers[i] = {
                    'weights': {k: jnp.zeros_like(w) for k, w in network.params['weights'].items()},
                    'biases': jnp.zeros_like(network.params['biases']),
                    'gamma': jnp.zeros_like(network.params['gamma']),
                    'beta': jnp.zeros_like(network.params['beta'])
                }
return self.best_network
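
# --- Minimal usage sketch ---
# A hedged, illustrative example (XOR is not part of this module): dataset
# shapes and the {-1, +1} label convention follow the loss function above,
# and Genome/Network are assumed to behave as they are used in this file.
if __name__ == "__main__":
    # Toy XOR problem with labels in {-1, +1}
    x = jnp.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
    y = jnp.array([-1.0, 1.0, 1.0, -1.0])
    neat = BackpropNEAT(population_size=5, n_inputs=2, n_outputs=1, n_hidden=16)
    best = neat.evolve(x, y, n_generations=10)
    print(f"Best fitness: {neat.best_fitness:.4f}")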