"""BackpropNEAT implementation.""" import jax import jax.numpy as jnp import numpy as np from typing import Dict, List, Tuple from .network import Network from .genome import Genome class BackpropNEAT: """Backpropagation-based NEAT implementation.""" def __init__(self, population_size=5, n_inputs=2, n_outputs=1, n_hidden=64, learning_rate=0.01, beta=0.9): """Initialize BackpropNEAT.""" self.population_size = population_size self.n_inputs = n_inputs self.n_outputs = n_outputs self.n_hidden = n_hidden self.learning_rate = learning_rate self.beta = beta # Initialize population self.population = [] self.momentum_buffers = [] for _ in range(population_size): # Create genome with skip connections genome = Genome(n_inputs, n_outputs, n_hidden) genome.add_layer_connections() # Add standard layer connections genome.add_skip_connections(0.3) # Add skip connections with 30% probability # Create network from genome network = Network(genome) self.population.append(network) # Initialize momentum buffer for this network momentum = { 'weights': {k: jnp.zeros_like(w) for k, w in network.params['weights'].items()}, 'biases': jnp.zeros_like(network.params['biases']), 'gamma': jnp.zeros_like(network.params['gamma']), 'beta': jnp.zeros_like(network.params['beta']) } self.momentum_buffers.append(momentum) # Create train step function self._train_step = self._make_train_step() # Bind train step to each network for i, network in enumerate(self.population): network.population_idx = i # Create a bound method for each network network._train_step = lambda p, x, y, idx=i: self._train_step(self, p, x, y, idx) def forward(self, params, x): """Forward pass through network.""" return self.population[0].forward(params, x) def _make_train_step(self): """Create training step function.""" # Constants for numerical stability eps = 1e-7 min_lr = 1e-6 max_lr = 1e-2 def loss_fn(params, x, y): """Compute loss for parameters.""" logits = self.forward(params, x) # Binary cross entropy loss with label smoothing alpha = 0.1 # Label smoothing factor # Smooth labels y_smooth = (1 - alpha) * y + alpha * 0.5 # Convert logits to probabilities probs = 0.5 * (logits + 1) # Map from [-1,1] to [0,1] probs = jnp.clip(probs, eps, 1 - eps) # Compute loss with label smoothing bce_loss = -jnp.mean( 0.5 * (1 + y_smooth) * jnp.log(probs) + 0.5 * (1 - y_smooth) * jnp.log(1 - probs) ) # L2 regularization with very small weight l2_reg = sum(jnp.sum(w ** 2) for w in params['weights'].values()) return bce_loss + 0.000001 * l2_reg @jax.jit def compute_updates(params, x, y): """Compute gradients and loss.""" loss_value, grads = jax.value_and_grad(loss_fn)(params, x, y) return grads, loss_value def train_step(self, params, x, y, network_idx): """Perform single training step with momentum.""" # Compute gradients grads, loss_value = compute_updates(params, x, y) # Get momentum buffer for this network momentum = self.momentum_buffers[network_idx] # Gradient norm for adaptive learning rate grad_norm = jnp.sqrt( sum(jnp.sum(g ** 2) for g in grads['weights'].values()) + jnp.sum(grads['biases'] ** 2) + jnp.sum(grads['gamma'] ** 2) + jnp.sum(grads['beta'] ** 2) + eps # Add eps for numerical stability ) # Compute adaptive learning rate if grad_norm > 1.0: effective_lr = self.learning_rate / grad_norm else: effective_lr = self.learning_rate * (1.0 + jnp.log(grad_norm + eps)) # Clip learning rate to reasonable range effective_lr = jnp.clip(effective_lr, min_lr, max_lr) # Update weights momentum with adaptive learning rate new_weights = {} for k in params['weights'].keys(): grad = grads['weights'][k] # Update momentum with gradient clipping momentum['weights'][k] = ( self.beta * momentum['weights'][k] + (1 - self.beta) * jnp.clip(grad, -1.0, 1.0) ) # Apply update with weight decay weight_decay = 0.0001 * params['weights'][k] new_weights[k] = params['weights'][k] - effective_lr * ( momentum['weights'][k] + weight_decay ) # Update biases momentum momentum['biases'] = ( self.beta * momentum['biases'] + (1 - self.beta) * jnp.clip(grads['biases'], -1.0, 1.0) ) new_biases = params['biases'] - effective_lr * momentum['biases'] # Update layer norm parameters with smaller learning rate ln_lr = 0.1 * effective_lr # Slower updates for stability # Gamma (scale) momentum['gamma'] = ( self.beta * momentum['gamma'] + (1 - self.beta) * jnp.clip(grads['gamma'], -0.1, 0.1) ) new_gamma = params['gamma'] - ln_lr * momentum['gamma'] new_gamma = jnp.clip(new_gamma, 0.1, 10.0) # Prevent collapse # Beta (shift) momentum['beta'] = ( self.beta * momentum['beta'] + (1 - self.beta) * jnp.clip(grads['beta'], -0.1, 0.1) ) new_beta = params['beta'] - ln_lr * momentum['beta'] return { 'weights': new_weights, 'biases': new_biases, 'gamma': new_gamma, 'beta': new_beta }, loss_value return train_step def _mutate_genome(self, genome: Genome) -> Genome: """Mutate genome architecture.""" new_genome = genome.copy() # Mutate weights and biases for key in list(new_genome.params['weights'].keys()): if np.random.random() < 0.1: new_genome.params['weights'][key] += np.random.normal(0, 0.2) for key in list(new_genome.params['biases'].keys()): if np.random.random() < 0.1: new_genome.params['biases'][key] += np.random.normal(0, 0.2) return new_genome def _select_parent(self, fitnesses: List[float]) -> int: """Select parent using tournament selection.""" # Tournament selection tournament_size = 3 best_idx = np.random.randint(len(fitnesses)) best_fitness = fitnesses[best_idx] for _ in range(tournament_size - 1): idx = np.random.randint(len(fitnesses)) if fitnesses[idx] > best_fitness: best_idx = idx best_fitness = fitnesses[idx] return best_idx def _compute_fitness(self, network: Network, x: jnp.ndarray, y: jnp.ndarray, n_epochs: int = 100, batch_size: int = 32) -> float: """Compute fitness of network.""" n_samples = x.shape[0] best_loss = float('inf') best_accuracy = 0.0 # Initial prediction initial_pred = network.predict(x) initial_acc = float(jnp.mean((initial_pred == y))) # Train network no_improve = 0 for epoch in range(n_epochs): # Shuffle data perm = np.random.permutation(n_samples) x_shuffled = x[perm] y_shuffled = y[perm] # Train in batches epoch_losses = [] for i in range(0, n_samples, batch_size): batch_x = x_shuffled[i:min(i + batch_size, n_samples)] batch_y = y_shuffled[i:min(i + batch_size, n_samples)] # Train step network.params, loss = network._train_step(network.params, batch_x, batch_y) epoch_losses.append(float(loss)) # Update best loss avg_loss = float(np.mean(epoch_losses)) if avg_loss < best_loss: best_loss = avg_loss no_improve = 0 else: no_improve += 1 # Compute accuracy predictions = network.predict(x) accuracy = float(jnp.mean((predictions == y))) best_accuracy = max(best_accuracy, accuracy) # Print progress every 10 epochs if epoch % 10 == 0: print(f"Epoch {epoch}: Loss = {avg_loss:.4f}, Accuracy = {accuracy:.4f}") # Early stopping if good accuracy or no improvement if accuracy > 0.95 or no_improve >= 10: print(f"Early stopping at epoch {epoch}") print(f"Final accuracy: {accuracy:.4f}") break # Print improvement print(f"Network improved from {initial_acc:.4f} to {best_accuracy:.4f}") # Fitness based on accuracy fitness = best_accuracy return float(fitness) def evolve(self, x: jnp.ndarray, y: jnp.ndarray, n_generations: int = 50) -> Network: """Evolve network architectures.""" for generation in range(n_generations): print(f"\nGeneration {generation}") # Evaluate current population fitnesses = [] for network in self.population: fitness = self._compute_fitness(network, x, y) fitnesses.append(fitness) # Update best network if fitness > self.best_fitness: self.best_fitness = fitness self.best_network = Network(network.genome.copy()) print(f"New best fitness: {fitness:.4f}") # Create new population through selection and mutation new_population = [] # Keep best network (elitism) best_idx = np.argmax(fitnesses) new_population.append(Network(self.population[best_idx].genome.copy())) # Create rest of population while len(new_population) < self.population_size: # Select parent parent_idx = self._select_parent(fitnesses) parent = self.population[parent_idx].genome # Create child through mutation child_genome = self._mutate_genome(parent) new_population.append(Network(child_genome)) self.population = new_population return self.best_network