"""Train NEAT networks to play volleyball using hardware acceleration when available.""" |
import jax |
import jax.numpy as jnp |
from jax import random |
from evojax.task.slimevolley import SlimeVolley |
from typing import List, Tuple, Dict |
import numpy as np |
import time |
from PIL import Image |
import io |
import os |
def _configure_accelerator() -> None:
    """Set XLA/CUDA environment variables and report which JAX devices are visible.

    Any exception raised while doing so is caught and reported as a note
    instead of aborting the import of this module.
    """
    try:
        # NOTE(review): these are assigned after `import jax`; presumably they only
        # take effect because JAX initialises its backends lazily (on the first
        # `jax.devices()` call below) — confirm against the JAX GPU memory docs.
        os.environ.update({
            'CUDA_VISIBLE_DEVICES': '0',
            'XLA_PYTHON_CLIENT_PREALLOCATE': 'false',
            'XLA_PYTHON_CLIENT_ALLOCATOR': 'platform',
        })
        print(f"JAX devices available: {jax.devices()}")
        print(f"Using device: {jax.devices()[0].platform.upper()}")
    except Exception as err:
        print(f"Note: Using CPU - {err}")


_configure_accelerator()
class NodeGene:
    """A gene representing a node in the neural network.

    Attributes:
        id: node id, unique within its genome.
        type: node role; this file uses 'input', 'hidden' and 'output'.
        activation: activation name recorded on the gene (note that
            ``Network.forward`` applies tanh regardless of this value).
        bias: additive bias of the node.
    """

    def __init__(self, id: int, node_type: str, activation: str = 'tanh', bias: float = None):
        """Create a node gene.

        Args:
            id: node id, unique within its genome.
            node_type: node role ('input', 'hidden' or 'output').
            activation: activation name to record on the gene.
            bias: explicit initial bias; when omitted, it is drawn as
                ``0.1 * N(0, 1)`` from NumPy's global generator.
        """
        self.id = id
        self.type = node_type
        self.activation = activation
        # Previously the bias was seeded from hash(f"node_{id}"): every node with a
        # given id received the *same* bias in every genome (destroying initial
        # diversity), and the value changed between runs with PYTHONHASHSEED.
        # Drawing from NumPy's global RNG fixes both and honours np.random.seed().
        self.bias = float(np.random.normal() * 0.1) if bias is None else float(bias)
class ConnectionGene:
    """A gene representing a connection between nodes.

    Attributes:
        source: id of the node the connection reads from.
        target: id of the node the connection feeds into.
        weight: connection weight.
        enabled: disabled connections are kept in the genome but ignored.
        innovation: ``hash((source, target))`` — identical edges in different
            genomes therefore share an innovation number.
    """

    def __init__(self, source: int, target: int, weight: float = None, enabled: bool = True):
        """Create a connection gene.

        Args:
            source: id of the source node.
            target: id of the target node.
            weight: explicit weight; when omitted, it is drawn as
                ``0.1 * N(0, 1)`` from NumPy's global generator.
            enabled: whether the connection takes part in evaluation.
        """
        self.source = source
        self.target = target
        self.enabled = enabled
        self.innovation = hash((source, target))
        if weight is None:
            # Previously seeded from hash(f"conn_{source}_{target}"), which gave the
            # same default weight to an edge in every genome and varied between runs
            # with PYTHONHASHSEED; use NumPy's global RNG instead.
            weight = float(np.random.normal() * 0.1)
        self.weight = weight
class Genome:
    """Genotype of a NEAT network: node genes plus a list of connection genes.

    Node ids are laid out as: inputs ``0 .. n_inputs-1``, then the output nodes,
    then hidden nodes. JAX keys are seeded from NumPy's global generator, so
    ``np.random.seed`` makes genome construction and mutation reproducible.
    """

    @staticmethod
    def _fresh_key():
        """Return a new JAX PRNG key seeded from NumPy's global generator.

        Replaces the former ``int(time.time() * 1000)`` seed, under which genomes
        built (or mutated) within the same millisecond got identical random streams.
        """
        return random.PRNGKey(np.random.randint(0, 2**31 - 1))

    def __init__(self, n_inputs: int, n_outputs: int):
        """Build a random starting topology.

        Each input->output edge exists with probability 0.7 (weight 0.5 * N(0, 1)).
        Then 1-3 hidden nodes are added; each is wired from every input and to
        every output independently with probability 0.5 (same weight scale).

        Args:
            n_inputs: number of input nodes.
            n_outputs: ignored — exactly 3 output nodes are always created, which
                is what ``Network`` asserts. NOTE(review): presumably matches
                SlimeVolley's 3-button action vector; confirm before honouring it.
        """
        self.node_genes = {i: NodeGene(i, 'input') for i in range(n_inputs)}
        n_outputs = 3  # deliberately overrides the argument (see docstring)
        for i in range(n_outputs):
            self.node_genes[n_inputs + i] = NodeGene(n_inputs + i, 'output')
        self.connection_genes: List[ConnectionGene] = []
        master_key = self._fresh_key()

        # Direct input -> output wiring.
        for i in range(n_inputs):
            for j in range(n_outputs):
                master_key, key = random.split(master_key)
                if random.uniform(key, shape=()) < 0.7:
                    master_key, key = random.split(master_key)
                    weight = float(random.normal(key, shape=()) * 0.5)
                    self.connection_genes.append(
                        ConnectionGene(i, n_inputs + j, weight=weight)
                    )

        # 1..3 hidden nodes (randint's upper bound is exclusive).
        master_key, key = random.split(master_key)
        n_hidden = int(random.randint(key, (), 1, 4))
        hidden_start = n_inputs + n_outputs
        for i in range(n_hidden):
            node_id = hidden_start + i
            self.node_genes[node_id] = NodeGene(node_id, 'hidden')
            for j in range(n_inputs):
                master_key, key = random.split(master_key)
                if random.uniform(key, shape=()) < 0.5:
                    master_key, key = random.split(master_key)
                    weight = float(random.normal(key, shape=()) * 0.5)
                    self.connection_genes.append(
                        ConnectionGene(j, node_id, weight=weight)
                    )
            for j in range(n_outputs):
                master_key, key = random.split(master_key)
                if random.uniform(key, shape=()) < 0.5:
                    master_key, key = random.split(master_key)
                    weight = float(random.normal(key, shape=()) * 0.5)
                    self.connection_genes.append(
                        ConnectionGene(node_id, n_inputs + j, weight=weight)
                    )

    def mutate(self, config: Dict):
        """Mutate this genome in place.

        Reads ``config['weight_mutation_rate']``, ``'weight_mutation_power'``,
        ``'add_node_rate'`` and ``'add_connection_rate'``.

        - Each connection, with probability ``weight_mutation_rate``: 10% of the
          time its weight is re-drawn as 0.5 * N(0, 1), otherwise it is perturbed
          by ``weight_mutation_power`` * N(0, 1).
        - Each node's bias is perturbed by 0.1 * N(0, 1) with probability 0.1.
        - With probability ``add_node_rate`` a random *enabled* connection is
          split by a new hidden node (the old connection is disabled).
        - With probability ``add_connection_rate`` up to 10 attempts are made to
          add a new, not-yet-present connection ``source -> target`` with
          ``source < target`` and a non-input target.
        """
        key = self._fresh_key()

        for conn in self.connection_genes:
            key, subkey = random.split(key)
            if random.uniform(subkey, shape=()) < config['weight_mutation_rate']:
                key, subkey = random.split(key)
                if random.uniform(subkey, shape=()) < 0.1:
                    key, subkey = random.split(key)
                    conn.weight = float(random.normal(subkey, shape=()) * 0.5)
                else:
                    key, subkey = random.split(key)
                    conn.weight += float(random.normal(subkey) * config['weight_mutation_power'])

        for node in self.node_genes.values():
            key, subkey = random.split(key)
            if random.uniform(subkey, shape=()) < 0.1:
                key, subkey = random.split(key)
                node.bias += float(random.normal(subkey) * 0.1)

        key, subkey = random.split(key)
        if random.uniform(subkey, shape=()) < config['add_node_rate']:
            # Only split enabled connections: splitting a disabled one would add
            # fresh, enabled edges carrying signal the genome had switched off.
            enabled = [c for c in self.connection_genes if c.enabled]
            if enabled:
                conn = enabled[np.random.randint(len(enabled))]
                new_id = max(self.node_genes.keys()) + 1
                self.node_genes[new_id] = NodeGene(new_id, 'hidden')
                key, subkey = random.split(key)
                weight1 = float(random.normal(subkey, shape=()) * 0.5)
                key, subkey = random.split(key)
                weight2 = float(random.normal(subkey, shape=()) * 0.5)
                self.connection_genes.append(
                    ConnectionGene(conn.source, new_id, weight=weight1)
                )
                self.connection_genes.append(
                    ConnectionGene(new_id, conn.target, weight=weight2)
                )
                conn.enabled = False

        key, subkey = random.split(key)
        if random.uniform(subkey, shape=()) < config['add_connection_rate']:
            nodes = list(self.node_genes.keys())
            for _ in range(10):
                source = nodes[np.random.randint(len(nodes))]
                target = nodes[np.random.randint(len(nodes))]
                # Keep the source < target rule, and never target an input node:
                # Network.forward overwrites input values with the observation, so
                # such an edge would have no effect.
                if source >= target or self.node_genes[target].type == 'input':
                    continue
                if any(c.source == source and c.target == target
                       for c in self.connection_genes):
                    continue
                key, subkey = random.split(key)
                weight = float(random.normal(subkey, shape=()) * 0.5)
                self.connection_genes.append(
                    ConnectionGene(source, target, weight=weight)
                )
                break
class Network:
    """Phenotype of a ``Genome``: evaluates it as a network of tanh units.

    ``forward`` re-reads the genome on every call, so genes added or changed
    after construction (e.g. by ``Genome.mutate``) take effect.
    """

    def __init__(self, genome: "Genome"):
        """Wrap ``genome``.

        Raises:
            AssertionError: if the genome does not have exactly 3 output nodes.
        """
        self.genome = genome
        # Snapshots taken at construction, kept as public attributes;
        # forward() re-reads the genome instead of relying on them.
        self.input_nodes = self._nodes_of_type('input')
        self.hidden_nodes = self._nodes_of_type('hidden')
        self.output_nodes = self._nodes_of_type('output')
        assert len(self.output_nodes) == 3, f"Expected 3 output nodes, got {len(self.output_nodes)}"

    def _nodes_of_type(self, node_type: str) -> list:
        """Return the genome's current nodes of ``node_type``, sorted by id."""
        return sorted(
            (n for n in self.genome.node_genes.values() if n.type == node_type),
            key=lambda n: n.id,
        )

    def _evaluation_order(self, incoming: dict) -> list:
        """Order hidden and output nodes so each follows its hidden/output sources.

        The base order is hidden nodes by id, then output nodes by id. Repeated
        sweeps over that base order emit every node whose non-input sources
        are already emitted. Nodes that never become ready (they sit on, or
        downstream of, a cycle) are appended at the end in base order.
        For an unmutated genome this yields exactly "hidden, then outputs".
        """
        candidates = self._nodes_of_type('hidden') + self._nodes_of_type('output')
        pending = {n.id for n in candidates}
        emitted = set()
        order = []
        progress = True
        while progress:
            progress = False
            for node in candidates:
                if node.id in emitted:
                    continue
                sources = (src for src, _ in incoming.get(node.id, ()))
                if all(src not in pending or src in emitted for src in sources):
                    order.append(node)
                    emitted.add(node.id)
                    progress = True
        order.extend(n for n in candidates if n.id not in emitted)
        return order

    def forward(self, x: jnp.ndarray) -> jnp.ndarray:
        """Evaluate the network.

        Args:
            x: a single input vector or a batch of them (2-D, one row per
                sample); column ``i`` feeds the ``i``-th input node by id.

        Returns:
            Array of shape ``(batch, n_output_nodes)`` holding the tanh
            activations of the output nodes, ordered by id.
        """
        if len(x.shape) == 1:
            x = jnp.expand_dims(x, 0)
        batch_size = x.shape[0]

        # Every node starts at its bias, so a node read before it is computed
        # (only possible inside a cycle) contributes its bias.
        values = {}
        for node in self.genome.node_genes.values():
            values[node.id] = jnp.zeros((batch_size,)) + node.bias
        for i, node in enumerate(self._nodes_of_type('input')):
            values[node.id] = x[:, i]

        # Incoming enabled edges per target, in genome order (keeps the original
        # summation order) — built once instead of rescanned for every node.
        incoming = {}
        for conn in self.genome.connection_genes:
            if conn.enabled:
                incoming.setdefault(conn.target, []).append((conn.source, conn.weight))

        for node in self._evaluation_order(incoming):
            total = jnp.zeros((batch_size,)) + node.bias
            for source, weight in incoming.get(node.id, ()):
                total = total + values[source] * weight
            values[node.id] = jnp.tanh(total)

        return jnp.stack([values[n.id] for n in self._nodes_of_type('output')], axis=-1)
def _binarize_actions(raw_actions: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Threshold raw network outputs of shape (n, 3) into 0/1 button presses.

    When both column 0 and column 1 would fire, only the one with the larger
    raw output is kept. This mirrors the tie-break in ``record_gameplay``, so
    the policy scored during training is the one that gets recorded.
    """
    raw = np.asarray(raw_actions, dtype=np.float32)
    pressed = raw > threshold
    both_active = pressed[:, 0] & pressed[:, 1]
    prefer_left = raw[:, 0] > raw[:, 1]
    pressed[:, 0] = np.where(both_active, prefer_left, pressed[:, 0])
    pressed[:, 1] = np.where(both_active, ~prefer_left, pressed[:, 1])
    return pressed.astype(np.float32)


def evaluate_parallel(networks: List["Network"], env: "SlimeVolley", batch_size: int = 8,
                      max_steps: int = 1000) -> List[float]:
    """Score networks by total episode reward, running up to ``batch_size`` games at once.

    The environment steps a whole batch of games together, while each network's
    action is computed in a Python loop.

    Args:
        networks: networks to evaluate; each must provide ``forward``.
        env: vectorized environment. ``reset`` receives one PRNG key per game,
            and ``step(state, actions)`` returns ``(state, rewards, dones)``.
        batch_size: number of games simulated together.
        max_steps: upper bound on steps per batch.

    Returns:
        One fitness value per network, in input order.
    """
    fitness_scores: List[float] = []
    for start in range(0, len(networks), batch_size):
        batch = networks[start:start + batch_size]
        n_games = len(batch)
        key = random.PRNGKey(np.random.randint(0, 2**31 - 1))
        # NOTE(review): evojax tasks vmap reset over a batch of keys (shape (n, 2));
        # passing one bare key, as before, does not match that contract.
        states = env.reset(random.split(key, n_games))
        total_rewards = np.zeros(n_games)
        alive = np.ones(n_games, dtype=bool)
        for _ in range(max_steps):
            observations = np.asarray(states.obs) / 10.0
            # forward() returns (1, 3) per network; take row 0 so actions are (n, 3).
            raw_actions = np.stack([
                np.asarray(net.forward(obs[None, :]))[0]
                for net, obs in zip(batch, observations)
            ])
            states, rewards, dones = env.step(states, _binarize_actions(raw_actions))
            # Only count reward while a game is still running; finished games keep
            # being stepped until the whole batch is done.
            step_rewards = np.asarray(rewards, dtype=np.float64).reshape(n_games)
            total_rewards += np.where(alive, step_rewards, 0.0)
            alive &= ~np.asarray(dones, dtype=bool).reshape(n_games)
            if not alive.any():
                break
        fitness_scores.extend(float(r) for r in total_rewards)
    return fitness_scores
def create_next_generation(population: List["Network"], fitness_scores: List[float], config: Dict):
    """Create the next generation of networks based on the current population and fitness scores.

    - Elitism: the best ``max(2, 20%)`` networks by fitness are carried over
      unchanged (same objects).
    - The remaining slots are filled by tournament selection (3 distinct
      candidates, or fewer if the pool is smaller) among the best ``max(5, 50%)``.
      Each winner's genome is deep-copied and the copy is mutated with ``config``
      (via ``genome.mutate``), so parents and elites are never modified.

    Args:
        population: current networks; each exposes ``.genome`` and its type is
            constructible from a genome.
        fitness_scores: fitness per network, aligned with ``population``.
        config: mutation settings forwarded to ``genome.mutate``.

    Returns:
        A new list with the same length as ``population``.
    """
    import copy

    size = len(population)
    # The population is not sorted by fitness, so rank indices explicitly
    # (stable: ties keep their original order).
    ranked = sorted(range(size), key=lambda i: fitness_scores[i], reverse=True)

    n_elite = max(2, int(0.2 * size))
    next_population = [population[i] for i in ranked[:n_elite]]

    pool = ranked[:max(5, int(0.5 * size))]
    tournament_size = min(3, len(pool))
    while len(next_population) < size:
        picks = np.random.choice(len(pool), tournament_size, replace=False)
        winner = max((pool[p] for p in picks), key=lambda i: fitness_scores[i])
        parent = population[winner]
        # Copy first, mutate the copy, then build the child so its node lists
        # reflect the mutated genome.
        genome = copy.deepcopy(parent.genome)
        genome.mutate(config)
        next_population.append(type(parent)(genome))
    return next_population
def record_gameplay(network: "Network", env: "SlimeVolley", filename: str = 'gameplay.gif', max_steps: int = 1000):
    """Record a game played by the network and save it as a GIF.

    Plays a single game until the environment reports done or ``max_steps``
    steps have run. A frame is rendered before every step. The frames are then
    written with ``frames[0].save(filename, save_all=True, ...)`` at 50 ms per
    frame, looping forever. NOTE(review): presumably ``env.render`` returns PIL
    images; confirm.
    """
    frames = []
    key = random.PRNGKey(np.random.randint(0, 2**31 - 1))
    # NOTE(review): evojax tasks vmap reset over a batch of keys (shape (n, 2));
    # play one game by passing a batch of one.
    state = env.reset(random.split(key, 1))
    done = False
    steps = 0
    while not done and steps < max_steps:
        frames.append(env.render(state))
        # state.obs already has a leading batch axis; adding another (as before)
        # made forward() index past the end, which JAX silently clamps.
        raw_action = network.forward(state.obs / 10.0)
        thresholds = jnp.array([0.5, 0.5, 0.5])
        binary_action = (raw_action > thresholds).astype(jnp.float32)
        # If both buttons 0 and 1 fire, keep only the one with the larger raw output.
        both_active = jnp.logical_and(binary_action[:, 0] > 0, binary_action[:, 1] > 0)
        prefer_left = raw_action[:, 0] > raw_action[:, 1]
        binary_action = binary_action.at[:, 0].set(
            jnp.where(both_active, prefer_left.astype(jnp.float32), binary_action[:, 0])
        )
        binary_action = binary_action.at[:, 1].set(
            jnp.where(both_active, (~prefer_left).astype(jnp.float32), binary_action[:, 1])
        )
        state, _reward, done_flags = env.step(state, binary_action)
        done = bool(np.asarray(done_flags).any())
        steps += 1
    if frames:
        frames[0].save(
            filename,
            save_all=True,
            append_images=frames[1:],
            duration=50,
            loop=0
        )
        print(f"Gameplay recorded and saved to {filename}")
    else:
        print("No frames were recorded")
def main():
    """Main training loop with hardware acceleration when available.

    Evolves a population of 64 networks on SlimeVolley for up to 1000
    generations and stops early once the best fitness exceeds 8.0. A GIF is
    recorded whenever a generation beats the previous all-time best by more
    than 2.0; this includes generation 0, whose previous best is -inf. A final
    GIF of the champion is recorded at the end.
    """
    import copy

    env = SlimeVolley(max_steps=1000)
    config = {
        'population_size': 64,
        'batch_size': 8,
        'weight_mutation_rate': 0.95,
        'weight_mutation_power': 4.0,
        'add_node_rate': 0.0,
        'add_connection_rate': 0.0,
    }
    print("\nTraining Configuration:")
    print(f"Population Size: {config['population_size']}")
    print(f"Batch Size: {config['batch_size']}")
    print(f"Mutation Rate: {config['weight_mutation_rate']}")
    print("-" * 40)
    population = [
        Network(Genome(n_inputs=12, n_outputs=3))
        for _ in range(config['population_size'])
    ]
    best_fitness = float('-inf')
    best_network = None
    for generation in range(1000):
        start_time = time.time()
        print(f"\nGeneration {generation}")
        fitness_scores = evaluate_parallel(
            population,
            env,
            batch_size=config['batch_size']
        )
        max_fitness = max(fitness_scores)
        if max_fitness > best_fitness:
            previous_best = best_fitness
            best_idx = fitness_scores.index(max_fitness)
            best_fitness = max_fitness
            # Snapshot the champion so later in-place genome mutations can't alter it.
            best_network = copy.deepcopy(population[best_idx])
            print(f"New best fitness: {best_fitness:.2f}")
            # Compare against the best *before* this update; the old check used the
            # freshly overwritten best_fitness and therefore never fired.
            if max_fitness > previous_best + 2.0:
                record_gameplay(best_network, env, f"best_gen_{generation}.gif")
        if best_fitness > 8.0:
            print(f"Target fitness reached: {best_fitness:.2f}")
            break
        population = create_next_generation(
            population,
            fitness_scores,
            config
        )
        if generation % 5 == 0:
            gen_time = time.time() - start_time
            print(f"\nGeneration {generation} Stats:")
            print(f"Best Fitness: {max_fitness:.2f}")
            print(f"Average Fitness: {np.mean(fitness_scores):.2f}")
            print(f"Generation Time: {gen_time:.2f}s")
    print("\nTraining complete!")
    print(f"Best fitness achieved: {best_fitness:.2f}")
    if best_network is not None:
        record_gameplay(best_network, env, "final_gameplay.gif")


if __name__ == '__main__':
    main()