"""NEAT evolution implementation.""" import jax import jax.numpy as jnp import numpy as np from typing import List, Dict, Optional, Tuple, Callable from .network import Network from .genome import Genome class NEATEvolution: """NEAT evolution implementation with structural mutations.""" DEFAULT_CONFIG = { 'node_add_prob': 0.2, # Standard node addition rate 'conn_add_prob': 0.3, # Standard connection addition rate 'weight_mutate_prob': 0.8, # High chance of weight mutation 'weight_replace_prob': 0.1, # Low chance of complete weight replacement 'weight_perturb_size': 0.5, # Standard weight perturbation size 'bias_mutate_prob': 0.8, # High chance of bias mutation 'bias_replace_prob': 0.1, # Low chance of complete bias replacement 'bias_perturb_size': 0.5, # Standard bias perturbation size 'complexity_coefficient': 0.0, # No complexity penalty 'species_distance': 2.0, # Standard species distance 'species_elitism': 2, # Keep top 2 from each species 'survival_threshold': 0.3 # Keep 30% of population } def __init__(self, n_inputs: int, n_outputs: int, population_size: int, config: Optional[Dict] = None, key: Optional[jnp.ndarray] = None): """Initialize NEAT evolution. Args: n_inputs: Number of input nodes (12 for volleyball) n_outputs: Number of output nodes (3 for volleyball) population_size: Size of population config: Optional configuration parameters key: Random key for JAX """ self.n_inputs = n_inputs self.n_outputs = n_outputs self.population_size = population_size self.config = {**self.DEFAULT_CONFIG, **(config or {})} # Initialize random key if key is None: self.key = jax.random.PRNGKey(0) else: self.key = key # Initialize population self.population = self._init_population() self.generation = 0 self.innovation_number = 0 self.species = [] def _init_population(self) -> List[Genome]: """Initialize population with minimal networks.""" population = [] for _ in range(self.population_size): # Split random key self.key, subkey = jax.random.split(self.key) # Create genome with proper input/output sizes genome = Genome(self.n_inputs, self.n_outputs, subkey) # Add random hidden nodes (between 2-6) self.key, subkey = jax.random.split(self.key) n_hidden = int(jax.random.randint(subkey, (), 2, 7)) hidden_nodes = [] for _ in range(n_hidden): hidden_nodes.append(genome.add_node()) # Connect inputs to hidden with 50% probability for i in range(self.n_inputs): for h in hidden_nodes: self.key, subkey = jax.random.split(self.key) if jax.random.uniform(subkey) < 0.5: self.key, subkey = jax.random.split(self.key) weight = jax.random.normal(subkey) * 0.5 genome.add_connection(i, h, weight) # Connect hidden to outputs with 50% probability output_start = genome.n_nodes - self.n_outputs for h in hidden_nodes: for i in range(self.n_outputs): self.key, subkey = jax.random.split(self.key) if jax.random.uniform(subkey) < 0.5: self.key, subkey = jax.random.split(self.key) weight = jax.random.normal(subkey) * 0.5 genome.add_connection(h, output_start + i, weight) # Add skip connections with 30% probability for i in range(self.n_inputs): for j in range(self.n_outputs): self.key, subkey = jax.random.split(self.key) if jax.random.uniform(subkey) < 0.3: self.key, subkey = jax.random.split(self.key) weight = jax.random.normal(subkey) * 0.3 genome.add_connection(i, output_start + j, weight) population.append(genome) return population def ask(self) -> List[Network]: """Get current population as networks.""" return [Network(genome) for genome in self.population] def tell(self, fitnesses: List[float]) -> None: """Update population based on fitness scores.""" # Sort population by fitness sorted_pop = sorted(zip(self.population, fitnesses), key=lambda x: x[1], reverse=True) # For very small populations, keep at least one parent n_parents = max(1, int(self.population_size * self.config['survival_threshold'])) parents = [p for p, _ in sorted_pop[:n_parents]] # Ensure we have at least one parent if not parents: # If all fitnesses are equal (including all zeros), keep the first one parents = [sorted_pop[0][0]] # Create new population starting with the best performer new_population = [parents[0]] # Always keep the best one # Fill rest with mutated offspring while len(new_population) < self.population_size: # Select parent (with replacement) parent = parents[0] if len(parents) == 1 else np.random.choice(parents) child = parent.copy() # Mutate child child = self._mutate_genome(child, self.key) new_population.append(child) self.population = new_population self.generation += 1 def _mutate_genome(self, genome: Genome, key: jnp.ndarray) -> Genome: """Mutate a genome. Mutation types: 1. Add new nodes (30% chance) 2. Add new connections (50% chance) 3. Modify weights (80% chance) 4. Modify biases (70% chance) 5. Enable/disable connections (20% chance) """ # Split random key keys = jax.random.split(key, 6) # Add nodes if jax.random.uniform(keys[0]) < self.config['node_add_prob']: # Add 1-3 nodes with decreasing probability n_nodes = 1 while jax.random.uniform(keys[1]) < 0.3 and n_nodes < 4: # Pick random enabled connection enabled_conns = [(src, dst) for (src, dst), enabled in genome.connections.items() if enabled] if enabled_conns: src, dst = enabled_conns[int(jax.random.randint(keys[2], (), 0, len(enabled_conns)))] genome.add_node_between(src, dst) n_nodes += 1 # Add connections if jax.random.uniform(keys[1]) < self.config['conn_add_prob']: # Add multiple connections with decreasing probability n_conns = 0 max_attempts = 20 # Prevent infinite loops attempts = 0 while attempts < max_attempts and n_conns < 5: # Pick random nodes src = int(jax.random.randint(keys[2], (), 0, genome.n_nodes)) dst = int(jax.random.randint(keys[3], (), 0, genome.n_nodes)) # Add connection if valid and not already present if src != dst and (src, dst) not in genome.connections: weight = jax.random.normal(keys[4]) * 0.5 genome.add_connection(src, dst, weight) n_conns += 1 attempts += 1 # Mutate weights if jax.random.uniform(keys[2]) < self.config['weight_mutate_prob']: for conn in list(genome.connections.keys()): if genome.connections[conn]: # Only mutate enabled connections if jax.random.uniform(keys[3]) < self.config['weight_replace_prob']: # Reset weight genome.weights[conn] = jax.random.normal(keys[4]) * self.config['weight_perturb_size'] else: # Perturb weight genome.weights[conn] += jax.random.normal(keys[4]) * self.config['weight_perturb_size'] # Mutate biases if jax.random.uniform(keys[3]) < self.config['bias_mutate_prob']: for node in list(genome.biases.keys()): if jax.random.uniform(keys[4]) < self.config['bias_replace_prob']: # Reset bias genome.biases[node] = jax.random.normal(keys[5]) * self.config['bias_perturb_size'] else: # Perturb bias genome.biases[node] += jax.random.normal(keys[5]) * self.config['bias_perturb_size'] # Enable/disable connections for conn in list(genome.connections.keys()): if jax.random.uniform(keys[5]) < 0.2: # 20% chance per connection genome.connections[conn] = not genome.connections[conn] return genome def get_average_nodes(self) -> float: """Get average number of nodes in population.""" return np.mean([g.n_nodes for g in self.population]) def get_average_connections(self) -> float: """Get average number of connections in population.""" return np.mean([len(g.connections) for g in self.population]) def get_activation_distribution(self) -> Dict[str, float]: """Get distribution of activation functions in population. Returns: Dictionary mapping activation function names to their frequency """ # For now we only use ReLU return {'relu': 1.0} def run_evolution(self, evaluator: Callable[[Network], float], max_generations: int, fitness_threshold: float, reset_mutations: bool = True, max_stagnation: int = 15, verbose: bool = True) -> Tuple[Network, float]: """Run the evolution process Args: evaluator: Function that takes a network and returns its fitness max_generations: Maximum number of generations to run fitness_threshold: Target fitness to achieve reset_mutations: Whether to reset mutations when fitness improves max_stagnation: Maximum generations without improvement before stopping verbose: Whether to print progress Returns: Tuple of (best network, best fitness) """ best_fitness = float('-inf') best_network = None stagnation_counter = 0 for generation in range(max_generations): # Evaluate current population fitnesses = [] for genome in self.population: network = genome.to_network() fitness = evaluator(network) genome.fitness = fitness fitnesses.append(fitness) # Update best if improved if fitness > best_fitness: best_fitness = fitness best_network = network stagnation_counter = 0 if reset_mutations: self.reset_innovation() # Get statistics avg_fitness = sum(fitnesses) / len(fitnesses) generation_best = max(fitnesses) # Print progress if verbose: print(f"\nGeneration {generation}:") print(f" Best Fitness: {best_fitness:.2f}") print(f" Generation Best: {generation_best:.2f}") print(f" Average Nodes: {self.get_average_nodes():.1f}") print(f" Average Connections: {self.get_average_connections():.1f}") # Check for improvement if generation_best <= best_fitness: stagnation_counter += 1 else: stagnation_counter = 0 # Create next generation self.tell(fitnesses) # Stop if stagnated too long if stagnation_counter >= max_stagnation: if verbose: print(f"\nStopping: No improvement for {max_stagnation} generations") break if verbose: print("\nTraining complete!") print(f"Best fitness achieved: {best_fitness:.2f}") return best_network, best_fitness