"""NEAT evolution implementation."""
import jax
import jax.numpy as jnp
import numpy as np
from typing import List, Dict, Optional, Tuple, Callable
from .network import Network
from .genome import Genome
class NEATEvolution:
    """NEAT evolution implementation with structural mutations.

    Usage follows an ask/tell loop: ``ask()`` returns the current population
    as networks and ``tell(fitnesses)`` replaces it with the next generation.
    ``run_evolution`` wraps that loop with stopping criteria.
    """

    # Default hyper-parameters; entries passed via ``config`` to __init__
    # override individual values.
    DEFAULT_CONFIG = {
        'node_add_prob': 0.2,  # Standard node addition rate
        'conn_add_prob': 0.3,  # Standard connection addition rate
        'weight_mutate_prob': 0.8,  # High chance of weight mutation
        'weight_replace_prob': 0.1,  # Low chance of complete weight replacement
        'weight_perturb_size': 0.5,  # Standard weight perturbation size
        'bias_mutate_prob': 0.8,  # High chance of bias mutation
        'bias_replace_prob': 0.1,  # Low chance of complete bias replacement
        'bias_perturb_size': 0.5,  # Standard bias perturbation size
        'conn_toggle_prob': 0.2,  # Per-connection chance to flip its enabled flag
        # NOTE: the three settings below are not read anywhere in this class.
        'complexity_coefficient': 0.0,  # No complexity penalty
        'species_distance': 2.0,  # Standard species distance
        'species_elitism': 2,  # Keep top 2 from each species
        'survival_threshold': 0.3  # Top 30% of the population become parents
    }

    def __init__(self,
                 n_inputs: int,
                 n_outputs: int,
                 population_size: int,
                 config: Optional[Dict] = None,
                 key: Optional[jnp.ndarray] = None):
        """Initialize NEAT evolution.

        Args:
            n_inputs: Number of input nodes (12 for volleyball)
            n_outputs: Number of output nodes (3 for volleyball)
            population_size: Size of population; must be at least 1
            config: Optional overrides merged on top of ``DEFAULT_CONFIG``
            key: Random key for JAX; defaults to ``jax.random.PRNGKey(0)``.
                Stored as ``self.key`` and split for every later random draw.

        Raises:
            ValueError: If ``population_size`` is smaller than 1.
        """
        if population_size < 1:
            raise ValueError(
                f"population_size must be at least 1, got {population_size}")
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        self.population_size = population_size
        self.config = {**self.DEFAULT_CONFIG, **(config or {})}

        # Fall back to a fixed seed only when the caller gave no key.
        self.key = jax.random.PRNGKey(0) if key is None else key

        self.population = self._init_population()
        self.generation = 0
        self.innovation_number = 0
        self.species = []

    def _init_population(self) -> List["Genome"]:
        """Build the initial population of small random networks.

        Each genome is created as ``Genome(n_inputs, n_outputs, key)`` and then:
          * gets 2-6 hidden nodes,
          * each input->hidden and hidden->output pair is connected with
            probability 0.5 (weight = N(0, 1) * 0.5),
          * each input->output skip pair is connected with probability 0.3
            (weight = N(0, 1) * 0.3).

        Returns:
            A list of ``population_size`` genomes.
        """
        population = []
        for _ in range(self.population_size):
            self.key, subkey = jax.random.split(self.key)
            genome = Genome(self.n_inputs, self.n_outputs, subkey)

            # Locate the output nodes on the fresh genome, before any hidden
            # node exists: if hidden nodes are appended after the outputs,
            # ``n_nodes - n_outputs`` computed later would point at them.
            output_start = genome.n_nodes - self.n_outputs
            outputs = range(output_start, output_start + self.n_outputs)

            # Add random hidden nodes (between 2-6)
            self.key, subkey = jax.random.split(self.key)
            n_hidden = int(jax.random.randint(subkey, (), 2, 7))
            # NOTE(review): the statement that created hidden nodes was lost
            # from this file; this assumes ``Genome.add_node()`` adds one hidden
            # node and returns its id -- confirm against genome.py.
            hidden_nodes = [genome.add_node() for _ in range(n_hidden)]

            inputs = range(self.n_inputs)
            self._connect_randomly(genome, inputs, hidden_nodes, 0.5, 0.5)
            self._connect_randomly(genome, hidden_nodes, outputs, 0.5, 0.5)
            # Skip connections straight from inputs to outputs.
            self._connect_randomly(genome, inputs, outputs, 0.3, 0.3)

            population.append(genome)
        return population

    def _connect_randomly(self, genome, sources, targets, prob: float,
                          weight_scale: float) -> None:
        """Add each ``source -> target`` connection independently with probability ``prob``.

        Weights are ``N(0, 1) * weight_scale``. Splits ``self.key`` for every
        draw so that no two decisions share randomness.
        """
        for src in sources:
            for dst in targets:
                self.key, subkey = jax.random.split(self.key)
                if jax.random.uniform(subkey) < prob:
                    self.key, subkey = jax.random.split(self.key)
                    weight = jax.random.normal(subkey) * weight_scale
                    genome.add_connection(src, dst, weight)

    def ask(self) -> List["Network"]:
        """Return the current population as ``Network`` objects, in population order."""
        return [Network(genome) for genome in self.population]

    def tell(self, fitnesses: List[float]) -> None:
        """Replace the population with the next generation.

        The top ``survival_threshold`` fraction of genomes (at least one)
        becomes the parent pool. The single best genome is carried over
        unchanged. Every other slot is filled with a mutated copy of a parent
        drawn uniformly at random, with replacement. Each child gets its own
        PRNG key.

        Args:
            fitnesses: One score per genome, in the same order as
                ``self.population`` (and as returned by ``ask``). Higher is better.

        Raises:
            ValueError: If the population is empty or ``fitnesses`` does not
                have one entry per genome.
        """
        if not self.population or len(fitnesses) != len(self.population):
            raise ValueError(
                f"expected {len(self.population)} fitness values for a "
                f"non-empty population, got {len(fitnesses)}")

        ranked = [genome for genome, _ in sorted(
            zip(self.population, fitnesses), key=lambda pair: pair[1], reverse=True)]

        # For very small populations, keep at least one parent
        n_parents = max(1, int(self.population_size * self.config['survival_threshold']))
        parents = ranked[:n_parents]

        new_population = [parents[0]]  # Elitism: always keep the best one
        while len(new_population) < self.population_size:
            # Fresh keys per child; reusing one key would give every child of
            # the same parent the identical set of mutations.
            self.key, pick_key, mutate_key = jax.random.split(self.key, 3)
            parent = parents[int(jax.random.randint(pick_key, (), 0, len(parents)))]
            new_population.append(self._mutate_genome(parent.copy(), mutate_key))

        self.population = new_population
        self.generation += 1

    def _mutate_genome(self, genome: "Genome", key: jnp.ndarray) -> "Genome":
        """Apply structural and parametric mutations to ``genome`` in place.

        Steps, with probabilities taken from ``self.config``:
          1. With 'node_add_prob': split 1-3 random enabled connections with
             ``add_node_between``. The first node is certain; each further
             node has a 30% chance. Stops early if no connection is enabled.
          2. With 'conn_add_prob': make up to 20 random (src, dst) draws and
             add at most 5 new connections (weight = N(0, 1) * 0.5). The
             destination is never an input node (ids ``0..n_inputs-1``).
          3. With 'weight_mutate_prob': each enabled connection's weight is
             replaced with probability 'weight_replace_prob'; otherwise it is
             perturbed. Both use N(0, 1) * 'weight_perturb_size'.
          4. With 'bias_mutate_prob': the same scheme for every entry of
             ``genome.biases``, using the bias settings.
          5. Each connection's enabled flag is flipped with probability
             'conn_toggle_prob' (default 0.2).

        Every random decision uses its own subkey of ``key``. Reusing a key
        inside a loop would make all decisions identical (e.g. flipping every
        connection at once).

        Args:
            genome: Genome to mutate; modified in place.
            key: JAX PRNG key driving all random choices.

        Returns:
            The same ``genome`` object.
        """
        cfg = self.config

        def next_key():
            nonlocal key
            key, sub = jax.random.split(key)
            return sub

        def coin(prob):
            return bool(jax.random.uniform(next_key()) < prob)

        # 1. Add nodes by splitting enabled connections.
        if coin(cfg['node_add_prob']):
            n_added = 0
            while n_added < 3 and (n_added == 0 or coin(0.3)):
                enabled_conns = [conn for conn, enabled in genome.connections.items() if enabled]
                if not enabled_conns:
                    break  # nothing to split; avoids looping forever
                idx = int(jax.random.randint(next_key(), (), 0, len(enabled_conns)))
                src, dst = enabled_conns[idx]
                genome.add_node_between(src, dst)
                n_added += 1

        # 2. Add new connections.
        if coin(cfg['conn_add_prob']) and genome.n_nodes > self.n_inputs:
            n_conns, attempts = 0, 0
            max_attempts = 20  # Prevent infinite loops
            while attempts < max_attempts and n_conns < 5:
                attempts += 1
                src = int(jax.random.randint(next_key(), (), 0, genome.n_nodes))
                dst = int(jax.random.randint(next_key(), (), self.n_inputs, genome.n_nodes))
                if src != dst and (src, dst) not in genome.connections:
                    genome.add_connection(src, dst, jax.random.normal(next_key()) * 0.5)
                    n_conns += 1

        # 3. Mutate weights of enabled connections only.
        if coin(cfg['weight_mutate_prob']):
            enabled_conns = [conn for conn, enabled in genome.connections.items() if enabled]
            self._mutate_values(genome.weights, enabled_conns,
                                cfg['weight_replace_prob'], cfg['weight_perturb_size'],
                                next_key())

        # 4. Mutate biases.
        if coin(cfg['bias_mutate_prob']):
            self._mutate_values(genome.biases, list(genome.biases),
                                cfg['bias_replace_prob'], cfg['bias_perturb_size'],
                                next_key())

        # 5. Enable/disable connections, each with an independent draw.
        conns = list(genome.connections)
        if conns:
            toggle_prob = cfg.get('conn_toggle_prob', 0.2)
            flips = np.asarray(jax.random.uniform(next_key(), (len(conns),)) < toggle_prob)
            for conn, flip in zip(conns, flips):
                if flip:
                    genome.connections[conn] = not genome.connections[conn]

        return genome

    @staticmethod
    def _mutate_values(values: Dict, targets: List, replace_prob: float,
                       scale: float, key: jnp.ndarray) -> None:
        """Replace or perturb ``values[t]`` for each ``t`` in ``targets``, in place.

        Each target independently receives a fresh ``N(0, 1) * scale`` value
        with probability ``replace_prob``. Otherwise ``N(0, 1) * scale`` noise
        is added to its current value.
        """
        if not targets:
            return
        choice_key, noise_key = jax.random.split(key)
        replace = np.asarray(jax.random.uniform(choice_key, (len(targets),)) < replace_prob)
        noise = np.asarray(jax.random.normal(noise_key, (len(targets),))) * scale
        for target, do_replace, delta in zip(targets, replace, noise):
            if do_replace:
                values[target] = float(delta)
            else:
                values[target] += float(delta)

    def get_average_nodes(self) -> float:
        """Get average number of nodes in population."""
        return np.mean([g.n_nodes for g in self.population])

    def get_average_connections(self) -> float:
        """Get average number of connections (enabled or not) in population."""
        return np.mean([len(g.connections) for g in self.population])

    def get_activation_distribution(self) -> Dict[str, float]:
        """Get distribution of activation functions in population.

        Returns:
            Dictionary mapping activation function names to their frequency.
            Currently always ``{'relu': 1.0}``.
        """
        # For now we only use ReLU
        return {'relu': 1.0}

    def run_evolution(self, evaluator: Callable[["Network"], float], max_generations: int,
                      fitness_threshold: float, reset_mutations: bool = True,
                      max_stagnation: int = 15, verbose: bool = True) -> Tuple["Network", float]:
        """Run the evaluate / select / mutate loop until a stopping criterion is met.

        Each generation, every genome is converted with ``genome.to_network()``
        and scored by ``evaluator``. The population is then replaced via
        ``tell``. The loop stops early when the best fitness seen reaches
        ``fitness_threshold``, or after ``max_stagnation`` consecutive
        generations that fail to improve it.

        Args:
            evaluator: Function that takes a network and returns its fitness
            max_generations: Maximum number of generations to run
            fitness_threshold: Stop as soon as the best fitness is >= this value
            reset_mutations: Kept for backward compatibility. This class never
                adapts the mutation rates in ``self.config``, so there is
                nothing to reset and the flag has no effect.
            max_stagnation: Maximum generations without improvement before stopping
            verbose: Whether to print progress

        Returns:
            Tuple of (best network, best fitness); ``(None, -inf)`` if no
            generation was evaluated.
        """
        best_fitness = float('-inf')
        best_network = None
        stagnation_counter = 0

        for generation in range(max_generations):
            previous_best = best_fitness

            # Evaluate current population
            fitnesses = []
            for genome in self.population:
                network = genome.to_network()
                fitness = evaluator(network)
                fitnesses.append(fitness)
                if fitness > best_fitness:
                    best_fitness = fitness
                    best_network = network

            # Get statistics
            avg_fitness = sum(fitnesses) / len(fitnesses)
            generation_best = max(fitnesses)

            # Print progress
            if verbose:
                print(f"\nGeneration {generation}:")
                print(f" Best Fitness: {best_fitness:.2f}")
                print(f" Generation Best: {generation_best:.2f}")
                print(f" Average Fitness: {avg_fitness:.2f}")
                print(f" Average Nodes: {self.get_average_nodes():.1f}")
                print(f" Average Connections: {self.get_average_connections():.1f}")

            if best_fitness >= fitness_threshold:
                if verbose:
                    print(f"\nStopping: Fitness threshold {fitness_threshold:.2f} reached")
                break

            # Compare with the best from *before* this generation; comparing
            # with the already-updated best would make every generation stagnant.
            if best_fitness > previous_best:
                stagnation_counter = 0
            else:
                stagnation_counter += 1

            # Stop if stagnated too long
            if stagnation_counter >= max_stagnation:
                if verbose:
                    print(f"\nStopping: No improvement for {max_stagnation} generations")
                break

            # Create next generation
            self.tell(fitnesses)

        if verbose:
            print("\nTraining complete!")
            print(f"Best fitness achieved: {best_fitness:.2f}")
        return best_network, best_fitness