|
"""NEAT evolution implementation.""" |
|
|
|
import jax |
|
import jax.numpy as jnp |
|
import numpy as np |
|
from typing import List, Dict, Optional, Tuple, Callable |
|
from .network import Network |
|
from .genome import Genome |
|
|
|
class NEATEvolution:
    """NEAT-style neuroevolution with structural (node/connection) mutations.

    The class follows an ask/tell loop. ``ask()`` turns the current genomes
    into networks, the caller scores them, and ``tell(fitnesses)`` builds the
    next generation. ``run_evolution`` wraps that loop with stopping criteria.

    All randomness derives from the JAX PRNG key stored in ``self.key``. Every
    consumer receives its own freshly split subkey. A fixed key therefore gives
    a reproducible run, and no two mutations share random numbers.
    """

    # Default hyper-parameters; entries in the ``config`` argument override them.
    # NOTE(review): 'complexity_coefficient', 'species_distance' and
    # 'species_elitism' are not read anywhere in this class.
    DEFAULT_CONFIG = {
        'node_add_prob': 0.2,           # chance per mutation to insert hidden node(s)
        'node_add_extra_prob': 0.3,     # chance to insert another node after each one
        'max_nodes_per_mutation': 4,    # cap on nodes inserted by one mutation
        'conn_add_prob': 0.3,           # chance per mutation to try adding connections
        'conn_add_attempts': 20,        # random (src, dst) pairs tried per mutation
        'max_conns_per_mutation': 5,    # cap on connections added by one mutation
        'conn_toggle_prob': 0.2,        # per-connection chance to flip its enabled flag
        'weight_mutate_prob': 0.8,      # chance per mutation to touch the weights
        'weight_replace_prob': 0.1,     # per-weight chance to replace instead of perturb
        'weight_perturb_size': 0.5,     # std-dev of weight noise / replacement
        'bias_mutate_prob': 0.8,        # chance per mutation to touch the biases
        'bias_replace_prob': 0.1,       # per-bias chance to replace instead of perturb
        'bias_perturb_size': 0.5,       # std-dev of bias noise / replacement
        'complexity_coefficient': 0.0,
        'species_distance': 2.0,
        'species_elitism': 2,
        'survival_threshold': 0.3,      # fraction of ranked genomes allowed to breed
    }

    def __init__(self,
                 n_inputs: int,
                 n_outputs: int,
                 population_size: int,
                 config: Optional[Dict] = None,
                 key: Optional[jnp.ndarray] = None):
        """Initialize NEAT evolution and draw the initial population.

        Args:
            n_inputs: Number of input nodes (e.g. 12 for volleyball).
            n_outputs: Number of output nodes (e.g. 3 for volleyball).
            population_size: Number of genomes kept each generation.
            config: Optional overrides for ``DEFAULT_CONFIG`` entries.
            key: JAX PRNG key; defaults to ``jax.random.PRNGKey(0)``.
        """
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        self.population_size = population_size
        self.config = {**self.DEFAULT_CONFIG, **(config or {})}

        # The key must be in place before the population is drawn from it.
        self.key = jax.random.PRNGKey(0) if key is None else key

        self.population = self._init_population()
        self.generation = 0
        self.innovation_number = 0
        self.species = []

    @staticmethod
    def _rng_from_key(key: jnp.ndarray) -> np.random.Generator:
        """Derive a NumPy generator deterministically from a JAX PRNG key.

        Scalar ``jax.random`` calls inside Python loops are dispatch-bound and
        slow. A single JAX draw seeds a NumPy generator, which then serves the
        many small draws cheaply while staying reproducible from the JAX key.
        """
        seed = int(jax.random.randint(key, (), 0, np.iinfo(np.int32).max))
        return np.random.default_rng(seed)

    @staticmethod
    def _connect_randomly(genome: "Genome", rng: np.random.Generator,
                          sources, targets, prob: float, scale: float) -> None:
        """Add each ``src -> dst`` edge independently with probability ``prob``.

        New weights are drawn as ``N(0, 1) * scale``.
        """
        for src in sources:
            for dst in targets:
                if rng.random() < prob:
                    genome.add_connection(src, dst, float(rng.normal()) * scale)

    def _init_population(self) -> List["Genome"]:
        """Create ``population_size`` genomes with 2-6 hidden nodes and sparse random wiring."""
        population = []
        for _ in range(self.population_size):
            self.key, genome_key, draw_key = jax.random.split(self.key, 3)
            genome = Genome(self.n_inputs, self.n_outputs, genome_key)
            rng = self._rng_from_key(draw_key)

            # randint upper bound is exclusive -> 2..6 hidden nodes.
            n_hidden = int(rng.integers(2, 7))
            hidden_nodes = [genome.add_node() for _ in range(n_hidden)]

            # NOTE(review): assumes Genome keeps output nodes as the last
            # ``n_outputs`` ids even after add_node(); if hidden ids are appended
            # after the outputs, this points at hidden nodes instead — confirm.
            output_start = genome.n_nodes - self.n_outputs
            inputs = range(self.n_inputs)
            outputs = range(output_start, output_start + self.n_outputs)

            self._connect_randomly(genome, rng, inputs, hidden_nodes, 0.5, 0.5)
            self._connect_randomly(genome, rng, hidden_nodes, outputs, 0.5, 0.5)
            self._connect_randomly(genome, rng, inputs, outputs, 0.3, 0.3)

            population.append(genome)
        return population

    def ask(self) -> List["Network"]:
        """Return the current population as ``Network`` objects, in population order."""
        return [Network(genome) for genome in self.population]

    def tell(self, fitnesses: List[float]) -> None:
        """Advance one generation from the fitness of each current genome.

        The single best genome survives unchanged (elitism). The rest of the
        population is filled with mutated copies of parents drawn uniformly
        from the top ``survival_threshold`` fraction.

        Args:
            fitnesses: One score per genome, aligned with ``self.population``
                (higher is better).

        Raises:
            ValueError: If the population is empty or ``fitnesses`` has a
                different length than the population.
        """
        if not self.population:
            raise ValueError("cannot evolve an empty population")
        if len(fitnesses) != len(self.population):
            raise ValueError(
                f"expected {len(self.population)} fitness values, got {len(fitnesses)}"
            )

        ranked = [genome for genome, _ in sorted(
            zip(self.population, fitnesses), key=lambda pair: pair[1], reverse=True)]

        n_parents = max(1, int(self.population_size * self.config['survival_threshold']))
        parents = ranked[:n_parents]

        new_population = [parents[0]]
        while len(new_population) < self.population_size:
            # Fresh subkeys per child: otherwise every child would get the same mutation.
            self.key, select_key, mutate_key = jax.random.split(self.key, 3)
            if len(parents) == 1:
                parent = parents[0]
            else:
                parent = parents[int(jax.random.randint(select_key, (), 0, len(parents)))]
            new_population.append(self._mutate_genome(parent.copy(), mutate_key))

        self.population = new_population
        self.generation += 1

    def _mutate_genome(self, genome: "Genome", key: jnp.ndarray) -> "Genome":
        """Mutate ``genome`` in place and return it.

        Mutation steps (probabilities come from ``self.config``):
            1. With ``node_add_prob``, split a random enabled connection with a new
               node. Then keep splitting with ``node_add_extra_prob`` each time,
               up to ``max_nodes_per_mutation`` nodes.
            2. With ``conn_add_prob``, try up to ``conn_add_attempts`` random
               ``(src, dst)`` pairs. Add at most ``max_conns_per_mutation`` new,
               non-self connections with weight ``N(0, 1) * 0.5``.
            3. With ``weight_mutate_prob``, change every enabled connection's
               weight. Each weight is independently replaced (``weight_replace_prob``)
               or perturbed by ``N(0, 1) * weight_perturb_size``.
            4. With ``bias_mutate_prob``, do the same for every bias using the
               ``bias_*`` settings.
            5. Flip each connection's enabled flag independently with
               ``conn_toggle_prob``.

        Args:
            genome: Genome to mutate (callers pass a copy).
            key: JAX PRNG key used for every random draw in this mutation.

        Returns:
            The same ``genome`` object after mutation.
        """
        cfg = self.config
        rng = self._rng_from_key(key)

        # 1. Structural: insert hidden node(s) on enabled connections.
        if rng.random() < cfg['node_add_prob']:
            added = 0
            while added < cfg['max_nodes_per_mutation']:
                # Recomputed each pass: splitting may change which connections are enabled.
                enabled = [conn for conn, on in genome.connections.items() if on]
                if not enabled:
                    break
                src, dst = enabled[int(rng.integers(len(enabled)))]
                genome.add_node_between(src, dst)
                added += 1
                if rng.random() >= cfg['node_add_extra_prob']:
                    break

        # 2. Structural: add new connections between random node pairs.
        if rng.random() < cfg['conn_add_prob']:
            added = 0
            for _ in range(cfg['conn_add_attempts']):
                if added >= cfg['max_conns_per_mutation']:
                    break
                src = int(rng.integers(genome.n_nodes))
                dst = int(rng.integers(genome.n_nodes))
                if src != dst and (src, dst) not in genome.connections:
                    genome.add_connection(src, dst, float(rng.normal()) * 0.5)
                    added += 1

        # 3. Weights of enabled connections, each with its own random draws.
        if rng.random() < cfg['weight_mutate_prob']:
            size = cfg['weight_perturb_size']
            for conn, enabled in genome.connections.items():
                if not enabled:
                    continue
                noise = float(rng.normal()) * size
                if rng.random() < cfg['weight_replace_prob']:
                    genome.weights[conn] = noise
                else:
                    genome.weights[conn] += noise

        # 4. Node biases, each with its own random draws.
        if rng.random() < cfg['bias_mutate_prob']:
            size = cfg['bias_perturb_size']
            for node in list(genome.biases):
                noise = float(rng.normal()) * size
                if rng.random() < cfg['bias_replace_prob']:
                    genome.biases[node] = noise
                else:
                    genome.biases[node] += noise

        # 5. Independently toggle each connection's enabled flag.
        toggle_prob = cfg['conn_toggle_prob']
        for conn in list(genome.connections):
            if rng.random() < toggle_prob:
                genome.connections[conn] = not genome.connections[conn]

        return genome

    def reset_innovation(self) -> None:
        """Reset the innovation counter to zero.

        ``run_evolution`` calls this whenever the best fitness improves and
        ``reset_mutations`` is true.
        """
        self.innovation_number = 0

    def get_average_nodes(self) -> float:
        """Return the mean ``n_nodes`` over the current population."""
        return np.mean([g.n_nodes for g in self.population])

    def get_average_connections(self) -> float:
        """Return the mean number of connections (enabled or not) per genome."""
        return np.mean([len(g.connections) for g in self.population])

    def get_activation_distribution(self) -> Dict[str, float]:
        """Get distribution of activation functions in population.

        Returns:
            Dictionary mapping activation function names to their frequency.
            Currently a fixed placeholder: activations are not tracked per genome.
        """
        # NOTE(review): hard-coded; confirm it matches the activation Network uses.
        return {'relu': 1.0}

    def run_evolution(self, evaluator: Callable[["Network"], float], max_generations: int,
                      fitness_threshold: float, reset_mutations: bool = True,
                      max_stagnation: int = 15,
                      verbose: bool = True) -> Tuple[Optional["Network"], float]:
        """Run the evolution process.

        Args:
            evaluator: Function that takes a network and returns its fitness.
            max_generations: Maximum number of generations to run.
            fitness_threshold: Stop as soon as the best fitness reaches this value.
            reset_mutations: Whether to call ``reset_innovation`` when the best
                fitness improves.
            max_stagnation: Stop after this many consecutive generations without
                improving the best fitness.
            verbose: Whether to print progress.

        Returns:
            Tuple of (best network, best fitness). The network is ``None`` if no
            generation was evaluated.
        """
        best_fitness = float('-inf')
        best_network = None
        stagnation_counter = 0

        for generation in range(max_generations):
            networks = self.ask()
            fitnesses = []
            for genome, network in zip(self.population, networks):
                fitness = evaluator(network)
                genome.fitness = fitness
                fitnesses.append(fitness)

            # First index holding the generation's maximum fitness.
            best_idx = max(range(len(fitnesses)), key=fitnesses.__getitem__)
            generation_best = fitnesses[best_idx]
            avg_fitness = sum(fitnesses) / len(fitnesses)

            # Compare against the best from *previous* generations.
            if generation_best > best_fitness:
                best_fitness = generation_best
                best_network = networks[best_idx]
                stagnation_counter = 0
                if reset_mutations:
                    self.reset_innovation()
            else:
                stagnation_counter += 1

            if verbose:
                print(f"\nGeneration {generation}:")
                print(f" Best Fitness: {best_fitness:.2f}")
                print(f" Generation Best: {generation_best:.2f}")
                print(f" Average Fitness: {avg_fitness:.2f}")
                print(f" Average Nodes: {self.get_average_nodes():.1f}")
                print(f" Average Connections: {self.get_average_connections():.1f}")

            # Target met: stop before breeding so the evaluated population is kept.
            if best_fitness >= fitness_threshold:
                if verbose:
                    print(f"\nStopping: fitness threshold {fitness_threshold:.2f} reached")
                break

            self.tell(fitnesses)

            if stagnation_counter >= max_stagnation:
                if verbose:
                    print(f"\nStopping: No improvement for {max_stagnation} generations")
                break

        if verbose:
            print("\nTraining complete!")
            print(f"Best fitness achieved: {best_fitness:.2f}")

        return best_network, best_fitness
|
|