"""NEAT Genome implementation. |
This module implements the core NEAT genome structure and operations. |
Each genome represents a neural network with nodes (neurons) and connections (synapses). |
The genome can be mutated to evolve the network structure and weights over time. |
""" |
import random
import time
from dataclasses import dataclass, replace
from typing import Dict, List, Tuple, Optional

import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
@dataclass
class NodeGene:
    """A single neuron of a genome.

    Attributes:
        node_id: Unique identifier for this node (its key in a genome's
            node table)
        node_type: Role of the node: 'input', 'hidden', 'recurrent' or 'output'
        activation: Name of the activation function: 'tanh', 'relu',
            'sigmoid' or 'linear'
    """

    node_id: int
    node_type: str
    activation: str
@dataclass
class ConnectionGene:
    """A directed, weighted link between two nodes of a genome.

    Attributes:
        source: ID of the node the connection starts from
        target: ID of the node the connection leads to
        weight: Connection weight
        enabled: Whether the connection is active (True by default)
        innovation: Innovation number identifying this gene (0 by default)
    """

    source: int
    target: int
    weight: float
    # Optional trailing fields: a connection built without them is active
    # and carries innovation number 0.
    enabled: bool = True
    innovation: int = 0
class Genome:
    """NEAT Genome implementation.

    A genome represents a neural network as a collection of node and
    connection genes. The network topology and weights are modified through
    mutation and combined between two parents through crossover.

    Attributes:
        input_size: Number of input nodes
        output_size: Number of output nodes
        node_genes: Dictionary mapping node IDs to NodeGene objects
        connection_genes: List of ConnectionGene objects
        key: JAX PRNG key; it is split before every random draw so that
            successive draws differ
        innovation_number: Next innovation number to assign. The counter is
            per genome, so innovation numbers only line up across genomes for
            genes created in the same order (e.g. the initial topology).
    """

    # Sizes of the two fully connected hidden layers built by _init_minimal.
    HIDDEN1_SIZE = 8
    HIDDEN2_SIZE = 8

    def __init__(self, input_size: int, output_size: int,
                 key: Optional[jnp.ndarray] = None):
        """Initialize genome with specified number of inputs and outputs.

        Args:
            input_size: Number of input nodes
            output_size: Number of output nodes (the volleyball agent uses 3:
                left, right, jump)
            key: Optional JAX PRNG key for reproducible randomness. When
                omitted, a key is seeded from Python's ``random`` module.
        """
        self.input_size = input_size
        self.output_size = output_size
        self.node_genes: Dict[int, NodeGene] = {}
        self.connection_genes: List[ConnectionGene] = []
        if key is None:
            # Seed from Python's RNG rather than a wall-clock timestamp:
            # genomes created within the same millisecond would otherwise get
            # the same key, hence identical weights and mutations. 31 bits
            # keeps the seed inside int32 even with JAX 64-bit mode disabled.
            key = jrandom.PRNGKey(random.getrandbits(31))
        self.key = key
        self.innovation_number = 0
        self._init_minimal()

    def _next_key(self) -> jnp.ndarray:
        """Advance ``self.key`` and return a fresh subkey for one random draw.

        JAX keys are deterministic: sampling twice with the same key returns
        the same value, so every draw must use a newly split key.
        """
        self.key, subkey = jrandom.split(self.key)
        return subkey

    def _init_minimal(self):
        """Build the fixed starting topology: input -> hidden1 -> hidden2 -> output.

        Network structure (node IDs are contiguous, in this order):
        - Input nodes [0, input_size): game state inputs, linear activation
        - Hidden layer 1: 8 ReLU nodes, fully connected from the inputs
        - Hidden layer 2: 8 ReLU nodes, fully connected from hidden layer 1
        - Output nodes: output_size tanh nodes (left, right, jump for the
          volleyball agent), fully connected from hidden layer 2
        With input_size=8 and output_size=3 this gives inputs [0-7],
        hidden layer 1 [8-15], hidden layer 2 [16-23] and outputs [24-26].

        Initial weights use large standard deviations for faster learning:
        - Input->Hidden1: N(0, 2.0) for strong initial responses
        - Hidden1->Hidden2: N(0, 2.0) for feature processing
        - Hidden2->Output: N(0, 4.0) for decisive actions
        """
        input_ids = list(range(self.input_size))
        for node_id in input_ids:
            self.node_genes[node_id] = NodeGene(
                node_id=node_id,
                node_type='input',
                activation='linear'
            )
        hidden1_ids = self._add_dense_layer(
            input_ids, self.HIDDEN1_SIZE, 'hidden', 'relu', 2.0)
        hidden2_ids = self._add_dense_layer(
            hidden1_ids, self.HIDDEN2_SIZE, 'hidden', 'relu', 2.0)
        self._add_dense_layer(
            hidden2_ids, self.output_size, 'output', 'tanh', 4.0)

    def _add_dense_layer(self, source_ids: List[int], size: int,
                         node_type: str, activation: str,
                         weight_std: float) -> List[int]:
        """Append ``size`` new nodes, each fully connected from ``source_ids``.

        New node IDs continue after the highest existing ID. Each new node is
        followed by its incoming connections, so innovation numbers are
        assigned node by node. Weights are drawn from N(0, weight_std).

        Returns:
            IDs of the newly created nodes, in creation order
        """
        start = max(self.node_genes, default=-1) + 1
        # One batched draw per layer: row i holds the incoming weights of the
        # i-th new node.
        weights = np.asarray(
            jrandom.normal(self._next_key(), (size, len(source_ids)))
        ) * weight_std
        new_ids = []
        for row, node_id in enumerate(range(start, start + size)):
            self.node_genes[node_id] = NodeGene(
                node_id=node_id,
                node_type=node_type,
                activation=activation
            )
            for col, source_id in enumerate(source_ids):
                self.connection_genes.append(ConnectionGene(
                    source=source_id,
                    target=node_id,
                    weight=float(weights[row, col]),
                    enabled=True,
                    innovation=self.innovation_number
                ))
                self.innovation_number += 1
            new_ids.append(node_id)
        return new_ids

    def mutate(self, config: Dict):
        """Mutate the genome by modifying weights and network structure.

        Each connection weight is independently perturbed, with probability
        ``weight_mutation_rate``, by N(0, weight_mutation_power) noise. Then
        at most one add-node and at most one add-connection mutation happen,
        each with its own probability.

        Args:
            config: Dictionary containing mutation parameters:
                - weight_mutation_rate: Probability of mutating each weight
                - weight_mutation_power: Standard deviation for weight mutations
                - add_node_rate: Probability of adding a new node
                - add_connection_rate: Probability of adding a new connection
        """
        n_conns = len(self.connection_genes)
        if n_conns:
            # Draw every decision and every perturbation in one batch from
            # fresh subkeys, so each connection gets its own random values.
            mutate_mask = np.asarray(
                jrandom.uniform(self._next_key(), (n_conns,))
            ) < config['weight_mutation_rate']
            deltas = np.asarray(
                jrandom.normal(self._next_key(), (n_conns,))
            ) * config['weight_mutation_power']
            for conn, do_mutate, delta in zip(self.connection_genes, mutate_mask, deltas):
                if do_mutate:
                    conn.weight += float(delta)
        if config['add_node_rate'] > 0:
            if float(jrandom.uniform(self._next_key())) < config['add_node_rate']:
                self._add_node()
        if config['add_connection_rate'] > 0:
            if float(jrandom.uniform(self._next_key())) < config['add_connection_rate']:
                self._add_connection()

    def _split_connection(self, conn: 'ConnectionGene') -> int:
        """Disable ``conn`` and route it through a new hidden ReLU node.

        Adds ``conn.source -> new`` with weight 1.0 and ``new -> conn.target``
        with the old weight (the classic NEAT add-node mutation), consuming
        two innovation numbers.

        Returns:
            ID of the new node
        """
        conn.enabled = False
        new_id = max(self.node_genes) + 1
        self.node_genes[new_id] = NodeGene(
            node_id=new_id,
            node_type='hidden',
            activation='relu'
        )
        self.connection_genes.extend([
            ConnectionGene(
                source=conn.source,
                target=new_id,
                weight=1.0,
                enabled=True,
                innovation=self.innovation_number
            ),
            ConnectionGene(
                source=new_id,
                target=conn.target,
                weight=conn.weight,
                enabled=True,
                innovation=self.innovation_number + 1
            )
        ])
        self.innovation_number += 2
        return new_id

    def _add_node(self):
        """Add a new node by splitting a randomly chosen enabled connection.

        Only enabled connections are candidates: splitting a disabled one
        would silently reactivate a pathway the genome had switched off.
        Does nothing when no connection is enabled.
        """
        candidates = [c for c in self.connection_genes if c.enabled]
        if not candidates:
            return
        index = int(jrandom.randint(self._next_key(), (), 0, len(candidates)))
        self._split_connection(candidates[index])

    def _enabled_adjacency(self) -> Dict[int, List[int]]:
        """Map each node ID to the targets of its enabled outgoing connections."""
        adjacency: Dict[int, List[int]] = {}
        for conn in self.connection_genes:
            if conn.enabled:
                adjacency.setdefault(conn.source, []).append(conn.target)
        return adjacency

    @staticmethod
    def _reaches(adjacency: Dict[int, List[int]], start: int, goal: int) -> bool:
        """Return True if ``goal`` is reachable from ``start`` (``start == goal`` counts).

        Iterative depth-first search, so deep networks cannot hit Python's
        recursion limit.
        """
        stack = [start]
        visited = set()
        while stack:
            node = stack.pop()
            if node == goal:
                return True
            if node in visited:
                continue
            visited.add(node)
            stack.extend(adjacency.get(node, ()))
        return False

    def _creates_cycle(self, source: int, target: int,
                       adjacency: Dict[int, List[int]]) -> bool:
        """Cycle test for ``source -> target`` against a prebuilt adjacency map.

        Connections touching a 'recurrent' node never count as cycles.
        """
        if self.node_genes[source].node_type == 'recurrent' or \
           self.node_genes[target].node_type == 'recurrent':
            return False
        return self._reaches(adjacency, target, source)

    def _add_connection(self):
        """Add a new connection between two unconnected nodes, chosen at random.

        Candidates exclude pairs that already have a connection (enabled or
        not), connections into input nodes, and connections that would close
        a cycle between non-recurrent nodes. The weight is drawn from N(0, 1).
        Does nothing when no candidate exists.
        """
        existing = {(c.source, c.target) for c in self.connection_genes}
        # Build the enabled-edge graph once instead of rescanning every
        # connection for every DFS step of every candidate pair.
        adjacency = self._enabled_adjacency()
        candidates = []
        for source in self.node_genes:
            for target in self.node_genes:
                if (source, target) in existing:
                    continue
                # Input nodes carry the externally supplied game state; NEAT
                # never wires connections into them.
                if self.node_genes[target].node_type == 'input':
                    continue
                if self._creates_cycle(source, target, adjacency):
                    continue
                candidates.append((source, target))
        if not candidates:
            return
        pair_key, weight_key = jrandom.split(self._next_key())
        index = int(jrandom.randint(pair_key, (), 0, len(candidates)))
        source, target = candidates[index]
        self.connection_genes.append(ConnectionGene(
            source=source,
            target=target,
            weight=float(jrandom.normal(weight_key)),
            enabled=True,
            innovation=self.innovation_number
        ))
        self.innovation_number += 1

    def would_create_cycle(self, source: int, target: int) -> bool:
        """Check if adding connection would create cycle in network.

        A cycle is created when ``target`` can already reach ``source``
        through enabled connections (a self-loop counts). Connections
        touching a 'recurrent' node are always allowed.

        Args:
            source: Source node ID
            target: Target node ID
        Returns:
            True if connection would create cycle, False otherwise
        Raises:
            KeyError: if either node ID is not in the genome
        """
        return self._creates_cycle(source, target, self._enabled_adjacency())

    def add_node_between(self, source: int, target: int):
        """Add a new node between two nodes, splitting an existing connection.

        The first enabled ``source -> target`` connection is split. Does
        nothing when there is no such connection.

        Args:
            source: Source node ID
            target: Target node ID
        """
        for conn in self.connection_genes:
            if conn.source == source and conn.target == target and conn.enabled:
                self._split_connection(conn)
                break

    def add_connection(self, source: int, target: int, weight: Optional[float] = None) -> bool:
        """Add a new connection between two nodes.

        Args:
            source: Source node ID
            target: Target node ID
            weight: Optional connection weight. If None, a weight is drawn
                from N(0, 1).
        Returns:
            True if the connection was added; False if either node is
            unknown, a source -> target connection already exists (enabled or
            not), the target is an input node, or the connection would create
            a cycle between non-recurrent nodes
        """
        if source not in self.node_genes or target not in self.node_genes:
            return False
        if any(c.source == source and c.target == target for c in self.connection_genes):
            return False
        if self.node_genes[target].node_type == 'input':
            return False
        # Acyclicity is checked on the graph, not by comparing IDs: nodes
        # created by splitting get IDs above the outputs, so an ID-order rule
        # would forbid valid edges such as split node -> output.
        if self.would_create_cycle(source, target):
            return False
        if weight is None:
            weight = float(jrandom.normal(self._next_key()))
        self.connection_genes.append(ConnectionGene(
            source=source,
            target=target,
            weight=weight,
            enabled=True,
            innovation=self.innovation_number
        ))
        self.innovation_number += 1
        return True

    def crossover(self, other: 'Genome', key: jnp.ndarray) -> 'Genome':
        """Perform crossover between two genomes.

        This genome is the primary parent: the child inherits all of its
        nodes and connections. For a node ID present in both parents, and
        for a connection gene matched in ``other`` by innovation number and
        identical endpoints, the child's copy comes from either parent with
        probability 0.5. Endpoints must match too because innovation numbers
        are assigned per genome. All genes are copied, never shared.

        Args:
            other: Other parent genome
            key: JAX PRNG key; the child's own key is derived from it
        Returns:
            Child genome
        """
        child_key, node_key, conn_key = jrandom.split(key, 3)
        child = Genome(self.input_size, self.output_size, key=child_key)
        # Discard the freshly initialized topology: the child must contain
        # only inherited genes, not the initial ones plus duplicates.
        child.node_genes = {}
        child.connection_genes = []
        # Continue numbering past every inherited innovation so the child's
        # later mutations never reuse an existing number.
        child.innovation_number = max(self.innovation_number, other.innovation_number)

        node_ids = list(self.node_genes)
        take_self_node = np.asarray(jrandom.uniform(node_key, (len(node_ids),))) < 0.5
        for node_id, from_self in zip(node_ids, take_self_node):
            node = self.node_genes[node_id]
            if not from_self and node_id in other.node_genes:
                node = other.node_genes[node_id]
            child.node_genes[node_id] = replace(node)

        # First gene per innovation number in ``other`` (O(1) lookups).
        other_by_innovation: Dict[int, ConnectionGene] = {}
        for other_conn in other.connection_genes:
            other_by_innovation.setdefault(other_conn.innovation, other_conn)
        take_self_conn = np.asarray(
            jrandom.uniform(conn_key, (len(self.connection_genes),))
        ) < 0.5
        for conn, from_self in zip(self.connection_genes, take_self_conn):
            inherited = conn
            match = other_by_innovation.get(conn.innovation)
            if (not from_self and match is not None
                    and match.source == conn.source
                    and match.target == conn.target):
                inherited = match
            child.connection_genes.append(replace(inherited))
        return child

    def clone(self) -> 'Genome':
        """Create an independent copy of this genome.

        Node and connection genes are copied rather than shared, and the
        innovation counter is carried over so mutations of the copy never
        reuse an innovation number already present in its genes. The copy
        receives its own fresh PRNG key.

        Returns:
            Copy of genome
        """
        twin = Genome(self.input_size, self.output_size)
        twin.node_genes = {node_id: replace(node) for node_id, node in self.node_genes.items()}
        twin.connection_genes = [replace(conn) for conn in self.connection_genes]
        twin.innovation_number = self.innovation_number
        return twin

    @property
    def n_nodes(self) -> int:
        """Get total number of nodes in the genome."""
        return len(self.node_genes)