"""NEAT evolution implementation.""" |
import jax |
import jax.numpy as jnp |
import numpy as np |
from typing import List, Dict, Optional, Tuple, Callable |
from .network import Network |
from .genome import Genome |
class NEATEvolution:
    """Evolve a population of ``Genome`` objects with NEAT-style mutations.

    Usage follows an ask/tell pattern: :meth:`ask` wraps the current genomes
    in ``Network`` objects, :meth:`tell` breeds the next generation from
    fitness scores, and :meth:`run_evolution` evaluates and breeds in a loop.

    All randomness comes from JAX PRNG keys derived from ``self.key`` (or from
    the key handed to :meth:`_mutate_genome`); every draw gets its own
    sub-key, so no two random decisions share a key.
    """

    # Defaults merged underneath the user-supplied ``config`` in ``__init__``.
    # NOTE(review): 'complexity_coefficient', 'species_distance' and
    # 'species_elitism' are not read by any method of this class.
    DEFAULT_CONFIG = {
        'node_add_prob': 0.2,
        'conn_add_prob': 0.3,
        'weight_mutate_prob': 0.8,
        'weight_replace_prob': 0.1,
        'weight_perturb_size': 0.5,
        'bias_mutate_prob': 0.8,
        'bias_replace_prob': 0.1,
        'bias_perturb_size': 0.5,
        'complexity_coefficient': 0.0,
        'species_distance': 2.0,
        'species_elitism': 2,
        'survival_threshold': 0.3,
    }

    # Fixed limits of the structural mutations in ``_mutate_genome``.
    _EXTRA_NODE_PROB = 0.3            # chance required for each node insertion
    _MAX_NODE_INSERTIONS = 3          # upper bound on insertions per mutation
    _MAX_NEW_CONNECTIONS = 5          # upper bound on new connections per mutation
    _MAX_CONNECTION_ATTEMPTS = 20     # random (src, dst) pairs tried per mutation
    _NEW_CONNECTION_WEIGHT_SCALE = 0.5
    _TOGGLE_PROB = 0.2                # per-connection enable/disable flip chance

    def __init__(self,
                 n_inputs: int,
                 n_outputs: int,
                 population_size: int,
                 config: Optional[Dict] = None,
                 key: Optional[jnp.ndarray] = None):
        """Initialize NEAT evolution.

        Args:
            n_inputs: Number of input nodes (12 for volleyball)
            n_outputs: Number of output nodes (3 for volleyball)
            population_size: Size of population
            config: Optional overrides for ``DEFAULT_CONFIG`` entries
            key: Random key for JAX; ``PRNGKey(0)`` is used when omitted
        """
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        self.population_size = population_size
        self.config = {**self.DEFAULT_CONFIG, **(config or {})}
        self.key = jax.random.PRNGKey(0) if key is None else key
        self.generation = 0
        self.innovation_number = 0
        self.species = []
        self.population = self._init_population()

    def _next_key(self):
        """Advance ``self.key`` and return a fresh sub-key for one draw."""
        self.key, subkey = jax.random.split(self.key)
        return subkey

    @staticmethod
    def _key_stream(key):
        """Yield an endless sequence of distinct sub-keys derived from ``key``."""
        while True:
            key, subkey = jax.random.split(key)
            yield subkey

    @staticmethod
    def _chance(key, prob: float) -> bool:
        """Return True with probability ``prob`` using one uniform draw."""
        return bool(jax.random.uniform(key) < prob)

    def _maybe_connect(self, genome: "Genome", src, dst,
                       prob: float, scale: float) -> None:
        """With probability ``prob`` add ``src -> dst`` with a N(0, scale^2) weight."""
        if self._chance(self._next_key(), prob):
            weight = jax.random.normal(self._next_key()) * scale
            genome.add_connection(src, dst, weight)

    def _init_population(self) -> "List[Genome]":
        """Build ``population_size`` genomes with random hidden layers.

        Each genome receives 2-6 hidden nodes. Input->hidden and
        hidden->output connections are each added with probability 0.5
        (weight scale 0.5); direct input->output connections with
        probability 0.3 (weight scale 0.3).
        """
        population = []
        for _ in range(self.population_size):
            genome = Genome(self.n_inputs, self.n_outputs, self._next_key())
            n_hidden = int(jax.random.randint(self._next_key(), (), 2, 7))
            hidden_nodes = [genome.add_node() for _ in range(n_hidden)]

            for i in range(self.n_inputs):
                for h in hidden_nodes:
                    self._maybe_connect(genome, i, h, 0.5, 0.5)

            # NOTE(review): assumes the output nodes occupy the highest node
            # ids even after add_node() -- confirm against Genome.
            output_start = genome.n_nodes - self.n_outputs
            for h in hidden_nodes:
                for j in range(self.n_outputs):
                    self._maybe_connect(genome, h, output_start + j, 0.5, 0.5)

            for i in range(self.n_inputs):
                for j in range(self.n_outputs):
                    self._maybe_connect(genome, i, output_start + j, 0.3, 0.3)

            population.append(genome)
        return population

    def ask(self) -> "List[Network]":
        """Get current population as networks."""
        return [Network(genome) for genome in self.population]

    def tell(self, fitnesses: List[float]) -> None:
        """Replace the population with the next generation.

        The top ``survival_threshold`` fraction (at least one genome) become
        parents. The best genome is carried over unchanged; every other slot
        is filled with a mutated copy of a uniformly chosen parent.

        Args:
            fitnesses: One score per genome, in ``self.population`` order.

        Raises:
            ValueError: If the population is empty or the number of scores
                does not match the population size.
        """
        fitnesses = list(fitnesses)
        if not self.population:
            raise ValueError("cannot evolve an empty population")
        if len(fitnesses) != len(self.population):
            raise ValueError(
                f"expected {len(self.population)} fitness values, "
                f"got {len(fitnesses)}"
            )

        ranked = sorted(zip(self.population, fitnesses),
                        key=lambda pair: pair[1], reverse=True)
        n_parents = max(1, int(self.population_size * self.config['survival_threshold']))
        parents = [genome for genome, _ in ranked[:n_parents]]

        new_population = [parents[0]]  # elitism: the champion survives as-is
        while len(new_population) < self.population_size:
            if len(parents) == 1:
                parent = parents[0]
            else:
                # Select with the JAX key so runs are reproducible from the seed.
                pick = int(jax.random.randint(self._next_key(), (), 0, len(parents)))
                parent = parents[pick]
            # A fresh key per child; sharing one key would give every child
            # exactly the same mutation draws.
            child = self._mutate_genome(parent.copy(), self._next_key())
            new_population.append(child)

        self.population = new_population
        self.generation += 1

    def _mutate_genome(self, genome: "Genome", key: jnp.ndarray) -> "Genome":
        """Mutate ``genome`` in place and return it.

        Every random decision uses its own sub-key of ``key``. Steps, in order:
            1. Node insertion, gated by ``node_add_prob``: up to 3 insertions,
               each requiring a further 30% draw, splitting a random enabled
               connection.
            2. Connection addition, gated by ``conn_add_prob``: up to 20
               random (src, dst) attempts, adding at most 5 new connections.
            3. Weight mutation, gated by ``weight_mutate_prob``: each enabled
               connection is replaced (``weight_replace_prob``) or perturbed
               by noise scaled with ``weight_perturb_size``.
            4. Bias mutation, gated by ``bias_mutate_prob``: same scheme using
               ``bias_replace_prob`` and ``bias_perturb_size``.
            5. Each connection's enabled flag is flipped with 20% chance.
        """
        keys = self._key_stream(key)
        cfg = self.config

        if self._chance(next(keys), cfg['node_add_prob']):
            self._add_nodes(genome, keys)
        if self._chance(next(keys), cfg['conn_add_prob']):
            self._add_connections(genome, keys)
        if self._chance(next(keys), cfg['weight_mutate_prob']):
            self._perturb_weights(genome, keys)
        if self._chance(next(keys), cfg['bias_mutate_prob']):
            self._perturb_biases(genome, keys)

        for conn in list(genome.connections):
            if self._chance(next(keys), self._TOGGLE_PROB):
                genome.connections[conn] = not genome.connections[conn]
        return genome

    def _add_nodes(self, genome: "Genome", keys) -> None:
        """Split random enabled connections; ``keys`` is a sub-key iterator."""
        for _ in range(self._MAX_NODE_INSERTIONS):
            if not self._chance(next(keys), self._EXTRA_NODE_PROB):
                break
            enabled = [conn for conn, on in genome.connections.items() if on]
            if enabled:
                pick = int(jax.random.randint(next(keys), (), 0, len(enabled)))
                src, dst = enabled[pick]
                genome.add_node_between(src, dst)

    def _add_connections(self, genome: "Genome", keys) -> None:
        """Try random node pairs and connect those not yet connected."""
        added = 0
        for _ in range(self._MAX_CONNECTION_ATTEMPTS):
            if added >= self._MAX_NEW_CONNECTIONS:
                break
            src = int(jax.random.randint(next(keys), (), 0, genome.n_nodes))
            dst = int(jax.random.randint(next(keys), (), 0, genome.n_nodes))
            if src != dst and (src, dst) not in genome.connections:
                weight = jax.random.normal(next(keys)) * self._NEW_CONNECTION_WEIGHT_SCALE
                genome.add_connection(src, dst, weight)
                added += 1

    def _perturb_weights(self, genome: "Genome", keys) -> None:
        """Replace or perturb the weight of every enabled connection."""
        replace_prob = self.config['weight_replace_prob']
        size = self.config['weight_perturb_size']
        for conn, enabled in list(genome.connections.items()):
            if not enabled:
                continue
            replace = self._chance(next(keys), replace_prob)
            noise = jax.random.normal(next(keys)) * size
            if replace:
                genome.weights[conn] = noise
            else:
                genome.weights[conn] += noise

    def _perturb_biases(self, genome: "Genome", keys) -> None:
        """Replace or perturb every bias stored on the genome."""
        replace_prob = self.config['bias_replace_prob']
        size = self.config['bias_perturb_size']
        for node in list(genome.biases):
            replace = self._chance(next(keys), replace_prob)
            noise = jax.random.normal(next(keys)) * size
            if replace:
                genome.biases[node] = noise
            else:
                genome.biases[node] += noise

    def reset_innovation(self) -> None:
        """Reset the innovation counter to zero."""
        self.innovation_number = 0

    def get_average_nodes(self) -> float:
        """Get average number of nodes in population."""
        return float(np.mean([g.n_nodes for g in self.population]))

    def get_average_connections(self) -> float:
        """Get average number of connections in population."""
        return float(np.mean([len(g.connections) for g in self.population]))

    def get_activation_distribution(self) -> Dict[str, float]:
        """Get distribution of activation functions in population.

        Returns:
            Dictionary mapping activation function names to their frequency.
            Currently always ``{'relu': 1.0}``.
        """
        return {'relu': 1.0}

    def run_evolution(self, evaluator: "Callable[[Network], float]", max_generations: int,
                      fitness_threshold: float, reset_mutations: bool = True,
                      max_stagnation: int = 15, verbose: bool = True) -> "Tuple[Network, float]":
        """Run the evolution process.

        Each generation evaluates every genome, then either stops (target
        fitness reached) or breeds the next generation via :meth:`tell`.

        Args:
            evaluator: Function that takes a network and returns its fitness
            max_generations: Maximum number of generations to run
            fitness_threshold: Target fitness; the run stops once the best
                fitness reaches it
            reset_mutations: Whether to call :meth:`reset_innovation` each
                time a genome beats the best fitness so far
            max_stagnation: Maximum consecutive generations without
                improvement before stopping
            verbose: Whether to print progress

        Returns:
            Tuple of (best network, best fitness). The network is ``None`` and
            the fitness ``-inf`` when no genome was evaluated.
        """
        best_fitness = float('-inf')
        best_network = None
        stagnation_counter = 0

        for generation in range(max_generations):
            improved = False
            fitnesses = []
            for genome in self.population:
                network = genome.to_network()
                fitness = evaluator(network)
                genome.fitness = fitness
                fitnesses.append(fitness)
                if fitness > best_fitness:
                    best_fitness = fitness
                    best_network = network
                    improved = True
                    if reset_mutations:
                        self.reset_innovation()

            if verbose:
                generation_best = max(fitnesses)
                print(f"\nGeneration {generation}:")
                print(f" Best Fitness: {best_fitness:.2f}")
                print(f" Generation Best: {generation_best:.2f}")
                print(f" Average Nodes: {self.get_average_nodes():.1f}")
                print(f" Average Connections: {self.get_average_connections():.1f}")

            # Count only generations that failed to beat the previous best.
            stagnation_counter = 0 if improved else stagnation_counter + 1

            if best_fitness >= fitness_threshold:
                if verbose:
                    print(f"\nStopping: Fitness threshold {fitness_threshold:.2f} reached")
                break

            self.tell(fitnesses)

            if stagnation_counter >= max_stagnation:
                if verbose:
                    print(f"\nStopping: No improvement for {max_stagnation} generations")
                break

        if verbose:
            print("\nTraining complete!")
            print(f"Best fitness achieved: {best_fitness:.2f}")
        return best_network, best_fitness