import copy
import time
from typing import List, Tuple, Dict

import jax
import jax.numpy as jnp
import numpy as np
from evojax.task.slimevolley import SlimeVolley
from jax import random
class NodeGene:
    """A single node (neuron) of a NEAT genome.

    Attributes:
        id: Identifier of the node inside its genome.
        type: ``'input'``, ``'hidden'`` or ``'output'``.
        activation: Name of the activation function (``'tanh'`` by default).
        bias: Bias term, sampled from a normal distribution scaled by 0.1.
    """

    def __init__(self, id: int, node_type: str, activation: str = 'tanh'):
        """Create a node with a small random bias.

        Args:
            id: Identifier of the node. The name shadows the builtin, but it
                is part of the public signature and is therefore kept.
            node_type: ``'input'``, ``'hidden'`` or ``'output'``.
            activation: Name of the activation function.
        """
        self.id = id
        self.type = node_type  # 'input', 'hidden', or 'output'
        self.activation = activation
        # Seed from the id and a millisecond timestamp so that different
        # nodes (and different runs) do not share the same bias.
        timestamp = int(time.time() * 1000)
        key = random.PRNGKey(hash((id, timestamp)) % (2**32))
        self.bias = float(random.normal(key, shape=()) * 0.1)  # Small random bias
class ConnectionGene:
    """A weighted, directed connection between two nodes of a NEAT genome.

    Attributes:
        source: Id of the node the connection starts from.
        target: Id of the node the connection feeds into.
        weight: Connection weight.
        enabled: Whether the connection is active.
        innovation: ``hash((source, target))``; identical for every
            connection joining the same pair of node ids.
    """

    def __init__(self, source: int, target: int, weight: float = None, enabled: bool = True):
        """Create a connection, sampling a small random weight if none is given.

        Args:
            source: Id of the source node.
            target: Id of the target node.
            weight: Connection weight. When ``None`` a value is drawn from a
                normal distribution scaled by 0.1.
            enabled: Whether the connection starts out active.
        """
        self.source = source
        self.target = target
        if weight is None:
            # The PRNG key is only needed when a weight has to be sampled, so
            # it is built lazily. Seeding uses source, target and a
            # millisecond timestamp.
            timestamp = int(time.time() * 1000)
            key = random.PRNGKey(hash((source, target, timestamp)) % (2**32))
            key, subkey = random.split(key)
            weight = float(random.normal(subkey, shape=()) * 0.1)  # Small random weight
        self.weight = weight
        self.enabled = enabled
        self.innovation = hash((source, target))
class Genome:
    """NEAT genome: a set of node genes plus the connections between them.

    Node ids are laid out as inputs ``0 .. n_inputs-1``, then the three
    output nodes, then hidden nodes.

    Attributes:
        node_genes: Mapping from node id to ``NodeGene``.
        connection_genes: List of ``ConnectionGene`` objects.
    """

    def __init__(self, n_inputs: int, n_outputs: int):
        """Build a randomly wired initial genome.

        Args:
            n_inputs: Number of input nodes.
            n_outputs: Accepted for interface compatibility, but overridden:
                the genome always has exactly 3 outputs (left, right, jump).
        """
        # Create input nodes (0 to n_inputs-1)
        self.node_genes = {i: NodeGene(i, 'input') for i in range(n_inputs)}

        # Create exactly 3 output nodes for left, right, jump
        n_outputs = 3  # Force exactly 3 outputs
        for i in range(n_outputs):
            self.node_genes[n_inputs + i] = NodeGene(n_inputs + i, 'output')

        self.connection_genes: List[ConnectionGene] = []

        # Seed the wiring from the sizes and a millisecond timestamp
        timestamp = int(time.time() * 1000)
        master_key = random.PRNGKey(hash((n_inputs, n_outputs, timestamp)) % (2**32))

        # Direct input -> output connections, each present with 70% probability
        for i in range(n_inputs):
            for j in range(n_outputs):
                master_key = self._maybe_connect(master_key, i, n_inputs + j, 0.7)

        # Hidden nodes: randint's upper bound is exclusive, so 1 to 3 of them
        master_key, key = random.split(master_key)
        n_hidden = int(random.randint(key, (), 1, 4))
        hidden_start = n_inputs + n_outputs
        for i in range(n_hidden):
            node_id = hidden_start + i
            self.node_genes[node_id] = NodeGene(node_id, 'hidden')
            # Connect random inputs to this hidden node
            for j in range(n_inputs):
                master_key = self._maybe_connect(master_key, j, node_id, 0.5)
            # Connect this hidden node to random outputs
            for j in range(n_outputs):
                master_key = self._maybe_connect(master_key, node_id, n_inputs + j, 0.5)

    def _maybe_connect(self, key, source: int, target: int, probability: float):
        """Add ``source -> target`` with the given probability; return the advanced key."""
        key, subkey = random.split(key)
        if random.uniform(subkey, shape=()) < probability:
            key, subkey = random.split(key)
            weight = float(random.normal(subkey, shape=()) * 0.5)  # Larger initial weights
            self.connection_genes.append(ConnectionGene(source, target, weight=weight))
        return key

    def mutate(self, config: Dict):
        """Mutate this genome in place.

        Applies, in order: weight perturbation/reset, bias perturbation,
        an optional add-node mutation and an optional add-connection mutation.

        Args:
            config: Must provide ``'weight_mutation_rate'``,
                ``'weight_mutation_power'``, ``'add_node_rate'`` and
                ``'add_connection_rate'``.
        """
        # A constant seed would hand every genome the very same sequence of
        # mutation decisions and weight deltas, so seed per genome and per call.
        timestamp = int(time.time() * 1000)
        key = random.PRNGKey(hash((id(self), timestamp)) % (2**32))

        # Mutate connection weights
        for conn in self.connection_genes:
            key, subkey = random.split(key)
            if random.uniform(subkey, shape=()) < config['weight_mutation_rate']:
                key, subkey = random.split(key)
                if random.uniform(subkey, shape=()) < 0.1:
                    # Sometimes reset weight completely
                    key, subkey = random.split(key)
                    conn.weight = float(random.normal(subkey, shape=()) * 0.5)
                else:
                    # Otherwise adjust existing weight
                    key, subkey = random.split(key)
                    conn.weight += float(random.normal(subkey) * config['weight_mutation_power'])

        # Mutate node biases
        for node in self.node_genes.values():
            key, subkey = random.split(key)
            if random.uniform(subkey, shape=()) < 0.1:  # 10% chance to mutate bias
                key, subkey = random.split(key)
                node.bias += float(random.normal(subkey) * 0.1)

        # Add new node by splitting a randomly chosen connection
        key, subkey = random.split(key)
        if random.uniform(subkey, shape=()) < config['add_node_rate']:
            if self.connection_genes:
                conn = self.connection_genes[np.random.randint(len(self.connection_genes))]
                new_id = max(self.node_genes.keys()) + 1
                # Create new node with random bias
                self.node_genes[new_id] = NodeGene(new_id, 'hidden')
                # Create two new connections with some randomization
                key, subkey = random.split(key)
                weight1 = float(random.normal(subkey, shape=()) * 0.5)
                key, subkey = random.split(key)
                weight2 = float(random.normal(subkey, shape=()) * 0.5)
                self.connection_genes.append(
                    ConnectionGene(conn.source, new_id, weight=weight1)
                )
                self.connection_genes.append(
                    ConnectionGene(new_id, conn.target, weight=weight2)
                )
                # Disable old connection
                conn.enabled = False

        # Add new connection
        key, subkey = random.split(key)
        if random.uniform(subkey, shape=()) < config['add_connection_rate']:
            nodes = list(self.node_genes.keys())
            # Built once so the duplicate check is O(1) per attempt
            existing = {(c.source, c.target) for c in self.connection_genes}
            for _ in range(10):  # Try 10 times to find valid connection
                # int() keeps NumPy scalars out of the stored gene ids
                source = int(np.random.choice(nodes))
                target = int(np.random.choice(nodes))
                # NOTE(review): "source id < target id" is the original rule.
                # Output ids are lower than hidden ids here, so the rule
                # admits output -> hidden links and rejects hidden -> output
                # ones; confirm this is the intended constraint.
                if source < target and (source, target) not in existing:
                    key, subkey = random.split(key)
                    weight = float(random.normal(subkey, shape=()) * 0.5)
                    self.connection_genes.append(
                        ConnectionGene(source, target, weight=weight)
                    )
                    # One mutation adds at most one connection
                    break
class Network:
    """Executable phenotype of a ``Genome``.

    Nodes are grouped by type and sorted by id when the network is built;
    ``forward`` evaluates hidden nodes first, then output nodes, each group in
    ascending id order.
    """

    def __init__(self, genome: Genome):
        """Cache the genome's nodes, grouped by type and sorted by id.

        Args:
            genome: Genome to execute. It must contain exactly 3 output nodes.
        """
        self.genome = genome
        nodes = list(genome.node_genes.values())
        # Sort nodes by ID to ensure consistent ordering
        self.input_nodes = sorted((n for n in nodes if n.type == 'input'), key=lambda n: n.id)
        self.hidden_nodes = sorted((n for n in nodes if n.type == 'hidden'), key=lambda n: n.id)
        self.output_nodes = sorted((n for n in nodes if n.type == 'output'), key=lambda n: n.id)
        # Verify we have exactly 3 output nodes
        assert len(self.output_nodes) == 3, f"Expected 3 output nodes, got {len(self.output_nodes)}"

    def forward(self, x: jnp.ndarray) -> jnp.ndarray:
        """Run the network on a batch of observations.

        Args:
            x: Array of shape ``(input_dim,)`` or ``(batch_size, input_dim)``.

        Returns:
            Array of shape ``(batch_size, 3)`` holding the tanh activations of
            the output nodes, in ascending node-id order.
        """
        # Ensure input is 2D with shape (batch_size, input_dim)
        if len(x.shape) == 1:
            x = jnp.expand_dims(x, 0)
        batch_size = x.shape[0]

        # Every node starts at its bias, so a node that is read before it has
        # been evaluated contributes its bias rather than raising KeyError.
        values = {
            node.id: jnp.zeros((batch_size,)) + node.bias
            for node in self.genome.node_genes.values()
        }

        # Set input values
        for i, node in enumerate(self.input_nodes):
            values[node.id] = x[:, i]

        # Group the enabled connections by target once per call instead of
        # rescanning the whole connection list for every node.
        incoming = {}
        for conn in self.genome.connection_genes:
            if conn.enabled:
                incoming.setdefault(conn.target, []).append(conn)

        # NOTE(review): nodes are evaluated in id order, not topological
        # order. A hidden node that feeds a lower-numbered hidden node is read
        # before it has been computed; confirm whether that is acceptable.
        for node in self.hidden_nodes + self.output_nodes:
            total = jnp.zeros((batch_size,)) + node.bias
            for conn in incoming.get(node.id, ()):
                total = total + values[conn.source] * conn.weight
            # tanh is applied to every node; node.activation is not consulted
            values[node.id] = jnp.tanh(total)

        # Stack along a new last axis to get (batch_size, 3)
        outputs = [values[node.id] for node in self.output_nodes]
        return jnp.stack(outputs, axis=-1)
def _to_binary_action(raw_action: jnp.ndarray) -> jnp.ndarray:
    """Threshold raw outputs into (left, right, jump) and forbid left+right together."""
    thresholds = jnp.array([0.3, 0.3, 0.4])  # left, right, jump
    binary_action = (raw_action > thresholds).astype(jnp.float32)

    # When both directions fire, keep only the one with the larger raw output
    both_active = jnp.logical_and(binary_action[:, 0] > 0, binary_action[:, 1] > 0)
    prefer_left = raw_action[:, 0] > raw_action[:, 1]
    binary_action = binary_action.at[:, 0].set(
        jnp.where(both_active, prefer_left.astype(jnp.float32), binary_action[:, 0])
    )
    binary_action = binary_action.at[:, 1].set(
        jnp.where(both_active, (~prefer_left).astype(jnp.float32), binary_action[:, 1])
    )
    return binary_action


def _shaping_bonus(binary_action: jnp.ndarray, next_obs) -> float:
    """Return the sum of the per-step shaping bonuses (before scaling)."""
    # Flatten so the feature indices work whether or not the observation
    # carries a leading batch axis of size 1.
    features = jnp.reshape(next_obs, (-1,))

    # Small reward for movement to encourage exploration
    movement_reward = 0.1 if bool(jnp.any(binary_action[:, :2] > 0)) else 0.0
    # NOTE(review): index 1 is treated as the ball height here; verify it
    # against the environment's observation layout.
    ball_height = float(features[1])
    height_reward = 0.1 if ball_height > 0.5 else 0.0
    ball_x = float(features[4])  # Ball x position
    ball_vx = float(features[6])  # Ball x velocity
    position_reward = 0.2 if ball_x > 0 else 0.0  # Reward for keeping ball on opponent's side
    velocity_reward = 0.1 if ball_vx > 0 else 0.0  # Reward for hitting ball towards opponent
    return movement_reward + height_reward + position_reward + velocity_reward


def evaluate_network(network: Network, env: SlimeVolley, n_episodes: int = 10) -> float:
    """Return the average shaped episode reward of ``network`` on ``env``.

    Each episode runs until the environment reports ``done`` or 1000 steps
    have elapsed. The per-step reward is twice the environment reward plus
    half of the shaping bonuses; an episode that ends with a positive
    environment reward earns an extra 10.

    Args:
        network: Object exposing ``forward(obs)`` that returns an array of
            shape ``(batch_size, 3)``.
        env: Environment exposing ``reset(keys)`` and ``step(state, action)``.
        n_episodes: Number of episodes to average over.

    Returns:
        Mean shaped reward per episode.
    """
    total_reward = 0.0
    # Generate a unique key for this evaluation
    timestamp = int(time.time() * 1000)
    network_id = id(network)
    master_key = random.PRNGKey(hash((network_id, timestamp)) % (2**32))

    for episode in range(n_episodes):
        # The environment is reset with a batch of one key
        master_key, reset_key = random.split(master_key)
        state = env.reset(reset_key[None, :])  # Add batch dimension
        done = False
        reward = 0.0
        episode_reward = 0.0
        steps = 0

        while not done and steps < 1000:  # Add step limit
            # The reset above is batched, so the observation presumably
            # already has a leading batch axis; reshape yields (1, obs_dim)
            # either way instead of stacking a second batch axis on top.
            obs = jnp.reshape(state.obs, (1, -1)) / 10.0  # scale inputs

            # Get action from network (shape: batch_size, 3)
            raw_action = network.forward(obs)
            binary_action = _to_binary_action(raw_action)

            # Step environment
            next_state, reward, done = env.step(state, binary_action)  # Already batched

            # Process reward and done flag
            if isinstance(reward, jnp.ndarray):
                reward = float(jnp.reshape(reward, (-1,))[0])  # Get first element if batched
            if isinstance(done, jnp.ndarray):
                done = bool(jnp.reshape(done, (-1,))[0])  # Convert to Python bool

            # Game outcome counts double, shaping bonuses count half
            bonus_reward = _shaping_bonus(binary_action, next_state.obs)
            episode_reward += reward * 2.0 + bonus_reward * 0.5

            state = next_state
            steps += 1

        # Bonus when the episode ended on a positive environment reward
        if done and reward > 0:
            episode_reward += 10.0
        total_reward += episode_reward

    return total_reward / n_episodes
def main():
    """Evolve SlimeVolley controllers with a simple NEAT-style loop.

    Every generation evaluates the whole population, keeps the best networks
    unchanged as elites and fills the remaining slots with mutated clones of
    tournament winners drawn from the top of the ranking.
    """
    # Initialize environment
    env = SlimeVolley()

    # NEAT configuration
    config = {
        'population_size': 50,  # Smaller population for faster iteration
        'weight_mutation_rate': 0.8,
        'weight_mutation_power': 0.3,  # Increased for more exploration
        'add_node_rate': 0.3,
        'add_connection_rate': 0.5,
    }

    # Create initial population
    population = [
        Network(Genome(n_inputs=12, n_outputs=3))
        for _ in range(config['population_size'])
    ]

    best_fitness = float('-inf')
    generations_without_improvement = 0
    n_elite = 5  # Fewer elites

    # Evolution loop
    for generation in range(500):  # More generations
        print(f"\nGeneration {generation}")
        print("-" * 20)

        # Evaluate population
        fitnesses = []
        for i, net in enumerate(population):
            fitness = evaluate_network(net, env)
            fitnesses.append(fitness)
            print(f"Network {i}: Fitness = {fitness:.2f}")
            if fitness > best_fitness:
                best_fitness = fitness
                generations_without_improvement = 0
                print(f"New best fitness: {best_fitness:.2f}")

        # Check for improvement
        generations_without_improvement += 1
        if generations_without_improvement > 20:
            print("No improvement for 20 generations, increasing mutation rates")
            config['weight_mutation_rate'] = min(1.0, config['weight_mutation_rate'] * 1.2)
            config['weight_mutation_power'] = min(0.5, config['weight_mutation_power'] * 1.2)
            generations_without_improvement = 0

        # Print progress
        avg_fitness = sum(fitnesses) / len(fitnesses)
        print(f"\nBest fitness: {best_fitness:.2f}")
        print(f"Average fitness: {avg_fitness:.2f}")

        # Selection and reproduction
        sorted_indices = np.argsort(fitnesses)[::-1]  # Best to worst

        # Keep best networks
        new_population = [population[i] for i in sorted_indices[:n_elite]]
        print(f"Keeping top {n_elite} networks")

        # Parents come from the 20 best networks; the tournament size is
        # bounded so sampling without replacement cannot exceed the pool.
        candidate_pool = sorted_indices[:20]
        tournament_size = min(5, len(candidate_pool))

        # Create offspring from best networks
        while len(new_population) < config['population_size']:
            # Tournament selection
            tournament = np.random.choice(candidate_pool, tournament_size, replace=False)
            parent_idx = tournament[np.argmax([fitnesses[i] for i in tournament])]
            parent = population[parent_idx]

            # Gene objects are mutable, so the child needs its own copies:
            # sharing them would let the child's mutation rewrite its parent
            # and any elite carried over unchanged.
            child_genome = copy.deepcopy(parent.genome)

            # Mutate child
            child_genome.mutate(config)

            # Add to new population
            new_population.append(Network(child_genome))

        population = new_population
        print(f"Created {len(population)} networks for next generation")
# Run the evolution loop only when the file is executed as a script
if __name__ == '__main__':
    main()