|
"""NEAT Genome implementation. |
|
|
|
This module implements the core NEAT genome structure and operations. |
|
Each genome represents a neural network with nodes (neurons) and connections (synapses). |
|
The genome can be mutated to evolve the network structure and weights over time. |
|
""" |
|
|
|
from dataclasses import dataclass |
|
import jax.numpy as jnp |
|
import jax.random as jrandom |
|
from typing import Dict, List, Tuple, Optional |
|
import time |
|
import random |
|
import numpy as np |
|
|
|
@dataclass
class NodeGene:
    """A single neuron in a genome.

    Fields are listed in constructor (positional) order.

    Attributes:
        node_id: Identifier of the node; a genome keys its node map by this ID
        node_type: Role of the node: 'input', 'hidden', 'recurrent', or 'output'
        activation: Activation function name: 'tanh', 'relu', 'sigmoid', or 'linear'
    """

    # Identifier used as the key in Genome.node_genes.
    node_id: int
    # One of 'input', 'hidden', 'recurrent', 'output'.
    node_type: str
    # One of 'tanh', 'relu', 'sigmoid', 'linear'.
    activation: str
|
|
|
@dataclass
class ConnectionGene:
    """A directed, weighted edge (synapse) between two nodes of a genome.

    Attributes:
        source: ID of the node the edge starts from
        target: ID of the node the edge points to
        weight: Connection weight
        enabled: Whether the edge is active; defaults to True
        innovation: Innovation number identifying this gene; defaults to 0
    """

    source: int
    target: int
    weight: float
    # Disabled genes are kept in the genome (for example after a node split)
    # rather than removed.
    enabled: bool = True
    innovation: int = 0
|
|
|
class Genome:
    """NEAT genome implementation.

    A genome represents a neural network as a collection of node and connection
    genes. The topology and weights can be modified through mutation and
    recombined with another genome through crossover.

    All randomness flows through JAX PRNG keys: every draw first splits
    ``self.key`` (or the key passed to ``crossover``), so successive draws are
    independent and a genome created from a fixed key is reproducible.

    Attributes:
        input_size: Number of input nodes
        output_size: Number of output nodes
        node_genes: Dictionary mapping node IDs to NodeGene objects
        connection_genes: List of ConnectionGene objects
        key: JAX PRNG key, split before every random draw
        innovation_number: Next innovation number this genome will assign
    """

    def __init__(self, input_size: int, output_size: int,
                 key: Optional[jnp.ndarray] = None):
        """Initialize genome with specified number of inputs and outputs.

        Args:
            input_size: Number of input nodes
            output_size: Number of output nodes (must be 3 for volleyball)
            key: Optional JAX PRNG key. When omitted, a key is seeded from
                Python's ``random`` module.
        """
        self.input_size = input_size
        self.output_size = output_size
        self.node_genes: Dict[int, NodeGene] = {}
        self.connection_genes: List[ConnectionGene] = []

        if key is None:
            # Seed from Python's RNG rather than the wall clock: clock-based
            # seeds collide for genomes created within the same millisecond,
            # which would give them identical networks and mutations.
            key = jrandom.PRNGKey(random.getrandbits(31))
        self.key = key

        self.innovation_number = 0
        self._init_minimal()

    def _next_key(self) -> jnp.ndarray:
        """Advance ``self.key`` and return a fresh subkey for one random draw."""
        self.key, subkey = jrandom.split(self.key)
        return subkey

    def _append_connection(self, source: int, target: int, weight: float):
        """Append an enabled connection gene with the next innovation number."""
        self.connection_genes.append(ConnectionGene(
            source=source,
            target=target,
            weight=weight,
            enabled=True,
            innovation=self.innovation_number
        ))
        self.innovation_number += 1

    def _init_minimal(self):
        """Initialize the starting fully connected feed-forward network.

        Node IDs are allocated consecutively:

        - Inputs ``[0, input_size)`` (linear)
        - Hidden layer 1: the next 8 IDs (relu)
        - Hidden layer 2: the next 8 IDs (relu)
        - Outputs: the next ``output_size`` IDs (tanh)

        With 8 inputs and 3 outputs this gives inputs 0-7, hidden1 8-15,
        hidden2 16-23 and outputs 24-26 (left, right, jump).

        Each layer is fully connected to the previous one. Weights are drawn
        independently from a zero-mean normal with standard deviation 2.0
        (input->hidden1 and hidden1->hidden2) or 4.0 (hidden2->output); the
        large scales give strong initial responses and decisive actions.
        """
        hidden1_size = 8
        hidden2_size = 8
        hidden1_start = self.input_size
        hidden2_start = hidden1_start + hidden1_size
        output_start = hidden2_start + hidden2_size

        input_ids = list(range(self.input_size))
        for node_id in input_ids:
            self.node_genes[node_id] = NodeGene(
                node_id=node_id,
                node_type='input',
                activation='linear'
            )

        hidden1_ids = self._add_dense_layer(
            input_ids, hidden1_start, hidden1_size, 'hidden', 'relu', 2.0)
        hidden2_ids = self._add_dense_layer(
            hidden1_ids, hidden2_start, hidden2_size, 'hidden', 'relu', 2.0)
        self._add_dense_layer(
            hidden2_ids, output_start, self.output_size, 'output', 'tanh', 4.0)

    def _add_dense_layer(self, source_ids: List[int], start: int, size: int,
                         node_type: str, activation: str,
                         weight_std: float) -> List[int]:
        """Create ``size`` nodes with IDs from ``start``, each fed by every source.

        Connections are appended grouped by target node, in ``source_ids``
        order, so innovation numbers are assigned in that order.

        Returns:
            IDs of the newly created nodes.
        """
        # One (size x n_sources) draw gives every connection its own weight.
        weights = np.asarray(
            jrandom.normal(self._next_key(), (size, len(source_ids)))) * weight_std
        new_ids = list(range(start, start + size))
        for node_id, row in zip(new_ids, weights):
            self.node_genes[node_id] = NodeGene(
                node_id=node_id,
                node_type=node_type,
                activation=activation
            )
            for source_id, weight in zip(source_ids, row):
                self._append_connection(source_id, node_id, float(weight))
        return new_ids

    def mutate(self, config: Dict):
        """Mutate the genome by modifying weights and network structure.

        Each connection is independently perturbed with probability
        ``weight_mutation_rate``. At most one node and at most one connection
        are then added, each with its own independent draw.

        Args:
            config: Dictionary containing mutation parameters:
                - weight_mutation_rate: Probability of mutating each weight
                - weight_mutation_power: Standard deviation for weight mutations
                - add_node_rate: Probability of adding a new node
                - add_connection_rate: Probability of adding a new connection
        """
        n_conns = len(self.connection_genes)
        if n_conns:
            mask_key, noise_key = self._next_key(), self._next_key()
            selected = (np.asarray(jrandom.uniform(mask_key, (n_conns,)))
                        < config['weight_mutation_rate'])
            deltas = (np.asarray(jrandom.normal(noise_key, (n_conns,)))
                      * config['weight_mutation_power'])
            for conn, pick, delta in zip(self.connection_genes, selected, deltas):
                if pick:
                    conn.weight += float(delta)

        if config['add_node_rate'] > 0:
            if float(jrandom.uniform(self._next_key())) < config['add_node_rate']:
                self._add_node()

        if config['add_connection_rate'] > 0:
            if float(jrandom.uniform(self._next_key())) < config['add_connection_rate']:
                self._add_connection()

    def _add_node(self):
        """Add a new node by splitting a randomly chosen enabled connection.

        Does nothing when no connection is enabled. Disabled connections are
        excluded so a split never revives a path that was switched off.
        """
        enabled = [c for c in self.connection_genes if c.enabled]
        if not enabled:
            return
        index = int(jrandom.randint(self._next_key(), (), 0, len(enabled)))
        self._split_connection(enabled[index])

    def _split_connection(self, conn: ConnectionGene) -> int:
        """Disable ``conn`` and route it through a new hidden relu node.

        The incoming edge gets weight 1.0 and the outgoing edge inherits the
        old weight.

        Returns:
            ID of the new node.
        """
        conn.enabled = False
        new_id = max(self.node_genes) + 1
        self.node_genes[new_id] = NodeGene(
            node_id=new_id,
            node_type='hidden',
            activation='relu'
        )
        self._append_connection(conn.source, new_id, 1.0)
        self._append_connection(new_id, conn.target, conn.weight)
        return new_id

    def _enabled_adjacency(self) -> Dict[int, List[int]]:
        """Map each node ID to the targets of its enabled outgoing connections."""
        adjacency: Dict[int, List[int]] = {}
        for conn in self.connection_genes:
            if conn.enabled:
                adjacency.setdefault(conn.source, []).append(conn.target)
        return adjacency

    @staticmethod
    def _reaches(start: int, goal: int, adjacency: Dict[int, List[int]]) -> bool:
        """Return True if ``goal`` is reachable from ``start`` (or equals it).

        Uses an explicit stack so deep networks cannot hit the recursion limit.
        """
        stack = [start]
        seen = set()
        while stack:
            node = stack.pop()
            if node == goal:
                return True
            if node in seen:
                continue
            seen.add(node)
            stack.extend(adjacency.get(node, ()))
        return False

    def _is_allowed(self, source: int, target: int,
                    adjacency: Dict[int, List[int]]) -> bool:
        """Return True if a new source->target edge is structurally valid.

        Input nodes never receive connections. Edges touching a recurrent node
        may form cycles; all others must keep the enabled graph acyclic (which
        also rules out self-loops on non-recurrent nodes).
        """
        if self.node_genes[target].node_type == 'input':
            return False
        if (self.node_genes[source].node_type == 'recurrent'
                or self.node_genes[target].node_type == 'recurrent'):
            return True
        return not self._reaches(target, source, adjacency)

    def _add_connection(self):
        """Add a new connection between two unconnected nodes.

        Candidates exclude pairs that already have a gene (enabled or not) and
        pairs rejected by ``_is_allowed``. One candidate is picked uniformly at
        random and given a standard-normal weight. Does nothing if no
        candidate exists.
        """
        existing = {(c.source, c.target) for c in self.connection_genes}
        # Built once so each candidate's cycle check is a single graph walk.
        adjacency = self._enabled_adjacency()
        candidates = [
            (source, target)
            for source in self.node_genes
            for target in self.node_genes
            if (source, target) not in existing
            and self._is_allowed(source, target, adjacency)
        ]
        if not candidates:
            return

        choice_key, weight_key = self._next_key(), self._next_key()
        index = int(jrandom.randint(choice_key, (), 0, len(candidates)))
        source, target = candidates[index]
        self._append_connection(source, target, float(jrandom.normal(weight_key)))

    def would_create_cycle(self, source: int, target: int) -> bool:
        """Check if adding connection would create cycle in network.

        Only enabled connections are considered. Connections touching a
        recurrent node never count as creating a cycle.

        Args:
            source: Source node ID
            target: Target node ID

        Returns:
            True if connection would create cycle, False otherwise

        Raises:
            KeyError: If either node ID is not in ``node_genes``.
        """
        if (self.node_genes[source].node_type == 'recurrent'
                or self.node_genes[target].node_type == 'recurrent'):
            return False
        # A new source->target edge closes a cycle iff source is already
        # reachable from target.
        return self._reaches(target, source, self._enabled_adjacency())

    def add_node_between(self, source: int, target: int):
        """Add a new node between two nodes, splitting an existing connection.

        The first enabled source->target connection is disabled and replaced by
        source->new (weight 1.0) and new->target (the old weight). Does nothing
        if no such enabled connection exists.

        Args:
            source: Source node ID
            target: Target node ID
        """
        conn = next(
            (c for c in self.connection_genes
             if c.source == source and c.target == target and c.enabled),
            None,
        )
        if conn is not None:
            self._split_connection(conn)

    def add_connection(self, source: int, target: int, weight: Optional[float] = None) -> bool:
        """Add a new connection between two nodes.

        Args:
            source: Source node ID
            target: Target node ID
            weight: Optional connection weight. If None, a standard-normal
                random weight is generated.

        Returns:
            True if the connection was added. False if a gene for the pair
            already exists, either node is missing, the target is an input
            node, or (for non-recurrent endpoints) the edge would close a
            cycle among the enabled connections.
        """
        if any(c.source == source and c.target == target for c in self.connection_genes):
            return False

        if source not in self.node_genes or target not in self.node_genes:
            return False

        # Node IDs do not reflect topological order (split nodes get IDs above
        # the outputs), so acyclicity is checked on the actual graph.
        if not self._is_allowed(source, target, self._enabled_adjacency()):
            return False

        if weight is None:
            weight = float(jrandom.normal(self._next_key()))

        self._append_connection(source, target, weight)
        return True

    def crossover(self, other: 'Genome', key: jnp.ndarray) -> 'Genome':
        """Perform crossover between two genomes.

        ``self`` is the primary parent: the child gets exactly ``self``'s node
        IDs and connection genes. For each gene that ``other`` also has, the
        child independently takes ``other``'s copy with probability 0.5.

        A connection counts as shared only if ``other`` has a gene with the
        same innovation number *and* the same endpoints. Innovation numbers
        are per-genome counters, so a number alone may refer to an unrelated
        edge whose nodes the child does not have.

        Args:
            other: Other parent genome
            key: JAX PRNG key driving the gene choices and seeding the child

        Returns:
            Child genome built from copies (not shared objects) of the parents'
            genes
        """
        child_key, node_key, conn_key = jrandom.split(key, 3)
        child = Genome(self.input_size, self.output_size, key=child_key)
        # Discard the child's freshly initialised minimal network; its genes
        # must come only from the parents.
        child.node_genes = {}
        child.connection_genes = []
        child.innovation_number = max(self.innovation_number, other.innovation_number)

        node_ids = list(self.node_genes)
        node_draws = np.asarray(jrandom.uniform(node_key, (len(node_ids),)))
        for node_id, draw in zip(node_ids, node_draws):
            parent_gene = self.node_genes[node_id]
            if node_id in other.node_genes and draw >= 0.5:
                parent_gene = other.node_genes[node_id]
            child.node_genes[node_id] = NodeGene(**vars(parent_gene))

        # Index other's genes once; the first gene with a given innovation wins.
        other_by_innovation: Dict[int, ConnectionGene] = {}
        for conn in other.connection_genes:
            other_by_innovation.setdefault(conn.innovation, conn)

        conn_draws = np.asarray(
            jrandom.uniform(conn_key, (len(self.connection_genes),)))
        for conn, draw in zip(self.connection_genes, conn_draws):
            chosen = conn
            match = other_by_innovation.get(conn.innovation)
            if (match is not None
                    and (match.source, match.target) == (conn.source, conn.target)
                    and draw >= 0.5):
                chosen = match
            child.connection_genes.append(ConnectionGene(**vars(chosen)))

        return child

    def clone(self) -> 'Genome':
        """Create an independent copy of this genome.

        Node and connection genes are copied. The innovation counter is carried
        over so new genes in the copy never reuse an existing number.
        ``self.key`` is split so the copy and the original draw different
        random numbers from here on.

        Returns:
            Copy of genome
        """
        twin = Genome(self.input_size, self.output_size, key=self._next_key())
        twin.node_genes = {
            node_id: NodeGene(**vars(gene)) for node_id, gene in self.node_genes.items()
        }
        twin.connection_genes = [ConnectionGene(**vars(conn)) for conn in self.connection_genes]
        twin.innovation_number = self.innovation_number
        return twin

    @property
    def n_nodes(self) -> int:
        """Get total number of nodes in the genome."""
        return len(self.node_genes)
|
|