|
"""Train NEAT networks to play volleyball using hardware acceleration when available.""" |
|
|
|
import copy
import io
import os
import time
from typing import Dict, List, Tuple

import jax
import jax.numpy as jnp
import numpy as np
from evojax.task.slimevolley import SlimeVolley
from jax import random
from PIL import Image
|
|
|
|
|
# Best-effort device setup: pin the process to GPU 0, set the XLA client
# memory options (no preallocation, 'platform' allocator), then report the
# devices JAX can see. Any failure is reported and otherwise ignored.
try:
    os.environ.update({
        'CUDA_VISIBLE_DEVICES': '0',
        'XLA_PYTHON_CLIENT_PREALLOCATE': 'false',
        'XLA_PYTHON_CLIENT_ALLOCATOR': 'platform',
    })

    visible_devices = jax.devices()
    print(f"JAX devices available: {visible_devices}")
    primary_platform = jax.devices()[0].platform
    print(f"Using device: {primary_platform.upper()}")
except Exception as err:
    print(f"Note: Using CPU - {err!s}")
|
|
|
class NodeGene:
    """One node of a genome: an id, a role and a scalar bias.

    The bias is a normal sample scaled by 0.1, drawn from a PRNG key that is
    seeded from the hash of the node's id label. Two nodes with the same id
    created in the same process therefore start with the same bias.
    """

    def __init__(self, id: int, node_type: str, activation: str = 'tanh'):
        self.id, self.type, self.activation = id, node_type, activation

        # Fold the label hash into the seed range before building the key.
        label = f"node_{id}"
        bias_key = random.PRNGKey(abs(hash(label)) % (2**32 - 1))
        sample = random.normal(bias_key, shape=())
        self.bias = float(sample * 0.1)
|
|
|
class ConnectionGene:
    """A directed, weighted link from node ``source`` to node ``target``.

    When no weight is supplied, one is drawn as a normal sample scaled by
    0.1 from a PRNG key seeded by the hash of the endpoint label. The
    ``innovation`` marker is the hash of the ``(source, target)`` pair.
    """

    def __init__(self, source: int, target: int, weight: float = None, enabled: bool = True):
        if weight is None:
            # Default weight: seeded from the endpoint pair.
            label = f"conn_{source}_{target}"
            init_key = random.PRNGKey(abs(hash(label)) % (2**32 - 1))
            weight = float(random.normal(init_key, shape=()) * 0.1)

        self.source, self.target = source, target
        self.enabled = enabled
        self.innovation = hash((source, target))
        self.weight = weight
|
|
|
class Genome:
    """A NEAT-style genome: node genes keyed by id plus a list of connections.

    Node ids are laid out as inputs ``0 .. n_inputs - 1``, then the output
    nodes, then hidden nodes. Structural mutations only ever create
    connections that go *from* an input/hidden node *to* a hidden/output
    node and that do not close a cycle.
    """

    # Every genome is built with exactly this many output nodes.
    N_OUTPUTS = 3

    def __init__(self, n_inputs: int, n_outputs: int):
        """Build a random initial genome.

        Args:
            n_inputs: Number of input nodes (ids ``0 .. n_inputs - 1``).
            n_outputs: Accepted for interface compatibility only. The genome
                is always created with ``N_OUTPUTS`` (3) output nodes, as the
                original code overrode this argument with 3.
        """
        n_outputs = self.N_OUTPUTS

        self.node_genes = {i: NodeGene(i, 'input') for i in range(n_inputs)}
        output_ids = [n_inputs + j for j in range(n_outputs)]
        for node_id in output_ids:
            self.node_genes[node_id] = NodeGene(node_id, 'output')

        self.connection_genes: List[ConnectionGene] = []
        key = self._new_key()

        # Direct input -> output links, each present with probability 0.7.
        for source in range(n_inputs):
            for target in output_ids:
                key = self._maybe_connect(key, source, target, 0.7)

        # Between 1 and 3 hidden nodes, each wired to every input and every
        # output with probability 0.5 per link.
        key, subkey = random.split(key)
        n_hidden = int(random.randint(subkey, (), 1, 4))
        hidden_start = n_inputs + n_outputs
        for node_id in range(hidden_start, hidden_start + n_hidden):
            self.node_genes[node_id] = NodeGene(node_id, 'hidden')
            for source in range(n_inputs):
                key = self._maybe_connect(key, source, node_id, 0.5)
            for target in output_ids:
                key = self._maybe_connect(key, node_id, target, 0.5)

    @staticmethod
    def _new_key():
        """Return a fresh JAX PRNG key seeded from NumPy's global RNG.

        A millisecond wall-clock seed repeats whenever two genomes are built
        or mutated within the same tick, which makes them draw identical
        random streams; NumPy's generator hands out a new seed on every call.
        """
        return random.PRNGKey(int(np.random.randint(0, 2**31 - 1)))

    @staticmethod
    def _chance(key, probability) -> bool:
        """Return True with the given probability, consuming ``key``."""
        return bool(random.uniform(key, shape=()) < probability)

    @staticmethod
    def _gaussian(key, scale) -> float:
        """Return one normal sample multiplied by ``scale`` as a Python float."""
        return float(random.normal(key, shape=()) * scale)

    def _maybe_connect(self, key, source, target, probability):
        """Add ``source -> target`` with the given probability; return the next key."""
        key, gate_key, weight_key = random.split(key, 3)
        if self._chance(gate_key, probability):
            self.connection_genes.append(
                ConnectionGene(source, target, weight=self._gaussian(weight_key, 0.5))
            )
        return key

    def _has_path(self, start, goal) -> bool:
        """Return True if ``goal`` is reachable from ``start`` along connections."""
        adjacency = {}
        for conn in self.connection_genes:
            adjacency.setdefault(conn.source, []).append(conn.target)
        stack, seen = [start], set()
        while stack:
            node_id = stack.pop()
            if node_id == goal:
                return True
            if node_id in seen:
                continue
            seen.add(node_id)
            stack.extend(adjacency.get(node_id, ()))
        return False

    def mutate(self, config: Dict):
        """Mutate this genome in place.

        Args:
            config: Must provide ``weight_mutation_rate``,
                ``weight_mutation_power``, ``add_node_rate`` and
                ``add_connection_rate``.
        """
        key = self._new_key()

        # Weight mutation: 10% of mutated weights are re-drawn, the rest are
        # perturbed by a normal step scaled by ``weight_mutation_power``.
        for conn in self.connection_genes:
            key, gate_key = random.split(key)
            if not self._chance(gate_key, config['weight_mutation_rate']):
                continue
            key, mode_key, value_key = random.split(key, 3)
            if self._chance(mode_key, 0.1):
                conn.weight = self._gaussian(value_key, 0.5)
            else:
                conn.weight += self._gaussian(value_key, config['weight_mutation_power'])

        # Bias mutation. Input nodes are skipped: their bias is never read,
        # because an input node's value is the observation itself.
        for node in self.node_genes.values():
            if node.type == 'input':
                continue
            key, gate_key = random.split(key)
            if self._chance(gate_key, 0.1):
                key, value_key = random.split(key)
                node.bias += self._gaussian(value_key, 0.1)

        key, gate_key = random.split(key)
        if self._chance(gate_key, config['add_node_rate']):
            key = self._mutate_add_node(key)

        key, gate_key = random.split(key)
        if self._chance(gate_key, config['add_connection_rate']):
            self._mutate_add_connection(key)

    def _mutate_add_node(self, key):
        """Split one enabled connection with a new hidden node; return the next key."""
        # Only enabled connections carry signal; splitting a disabled one
        # would add structure alongside a path that was already replaced.
        candidates = [c for c in self.connection_genes if c.enabled]
        if not candidates:
            return key
        conn = candidates[np.random.randint(len(candidates))]

        # The new node takes the next free id, which is larger than the id of
        # the node it feeds; evaluation code must therefore order nodes by
        # connectivity rather than by id.
        new_id = max(self.node_genes) + 1
        self.node_genes[new_id] = NodeGene(new_id, 'hidden')

        key, in_key, out_key = random.split(key, 3)
        self.connection_genes.append(
            ConnectionGene(conn.source, new_id, weight=self._gaussian(in_key, 0.5))
        )
        self.connection_genes.append(
            ConnectionGene(new_id, conn.target, weight=self._gaussian(out_key, 0.5))
        )
        conn.enabled = False
        return key

    def _mutate_add_connection(self, key):
        """Try up to 10 times to add one new, valid feed-forward connection."""
        # Comparing raw ids is not a feed-forward test here (output ids are
        # smaller than hidden ids), so endpoints are chosen by node role.
        sources = [nid for nid, node in self.node_genes.items() if node.type != 'output']
        targets = [nid for nid, node in self.node_genes.items() if node.type != 'input']
        if not sources or not targets:
            return
        existing = {(c.source, c.target) for c in self.connection_genes}

        for _ in range(10):
            source = sources[np.random.randint(len(sources))]
            target = targets[np.random.randint(len(targets))]
            if source == target or (source, target) in existing:
                continue
            both_hidden = (
                self.node_genes[source].type == 'hidden'
                and self.node_genes[target].type == 'hidden'
            )
            # Between hidden nodes keep the original low-id -> high-id rule.
            if both_hidden and source > target:
                continue
            # Reject links that would close a loop.
            if self._has_path(target, source):
                continue
            key, weight_key = random.split(key)
            self.connection_genes.append(
                ConnectionGene(source, target, weight=self._gaussian(weight_key, 0.5))
            )
            break
|
|
|
class Network:
    """Evaluates a genome as a feed-forward network.

    The genome is read at every ``forward`` call, so nodes and connections
    added to it after construction are honoured. ``tanh`` is applied to every
    hidden and output node; the ``activation`` field of the node genes is not
    consulted.
    """

    def __init__(self, genome: "Genome"):
        self.genome = genome
        self._sync_nodes()

        assert len(self.output_nodes) == 3, f"Expected 3 output nodes, got {len(self.output_nodes)}"

    def _sync_nodes(self):
        """Rebuild the id-sorted input/hidden/output node lists from the genome."""
        by_type = {'input': [], 'hidden': [], 'output': []}
        for node in self.genome.node_genes.values():
            if node.type in by_type:
                by_type[node.type].append(node)
        self.input_nodes = sorted(by_type['input'], key=lambda n: n.id)
        self.hidden_nodes = sorted(by_type['hidden'], key=lambda n: n.id)
        self.output_nodes = sorted(by_type['output'], key=lambda n: n.id)

    def _evaluation_order(self, incoming):
        """Return hidden and output nodes so that sources precede targets.

        The default order is hidden nodes by id followed by output nodes by
        id. A node is emitted only once every hidden/output node feeding it
        has been emitted; if a cycle makes that impossible, the first
        remaining node in the default order is emitted anyway.
        """
        computed = self.hidden_nodes + self.output_nodes
        computed_ids = {node.id for node in computed}
        upstream = {
            node.id: {
                conn.source
                for conn in incoming.get(node.id, ())
                if conn.source in computed_ids and conn.source != node.id
            }
            for node in computed
        }

        order, placed, remaining = [], set(), list(computed)
        while remaining:
            pick = next(
                (i for i, node in enumerate(remaining) if upstream[node.id] <= placed),
                0,
            )
            node = remaining.pop(pick)
            placed.add(node.id)
            order.append(node)
        return order

    def forward(self, x: jnp.ndarray) -> jnp.ndarray:
        """Run the network on a batch of observations.

        Args:
            x: Array of shape ``(n_inputs,)`` or ``(batch, n_inputs)``; column
                ``i`` feeds the ``i``-th input node in id order.

        Returns:
            Array of shape ``(batch, n_outputs)`` with outputs in id order.
        """
        if len(x.shape) == 1:
            x = jnp.expand_dims(x, 0)
        batch_size = x.shape[0]

        # The genome may have gained nodes since this object was created.
        self._sync_nodes()

        # Group enabled connections by target once instead of rescanning the
        # whole connection list for every node.
        incoming = {}
        for conn in self.genome.connection_genes:
            if conn.enabled:
                incoming.setdefault(conn.target, []).append(conn)

        # Every node starts at its bias; inputs are then overwritten.
        values = {
            node.id: jnp.zeros((batch_size,)) + node.bias
            for node in self.genome.node_genes.values()
        }
        for column, node in enumerate(self.input_nodes):
            values[node.id] = x[:, column]

        for node in self._evaluation_order(incoming):
            total = jnp.zeros((batch_size,)) + node.bias
            for conn in incoming.get(node.id, ()):
                total = total + values[conn.source] * conn.weight
            values[node.id] = jnp.tanh(total)

        return jnp.stack([values[node.id] for node in self.output_nodes], axis=-1)
|
|
|
def evaluate_parallel(networks: List["Network"], env: "SlimeVolley", batch_size: int = 8) -> List[float]:
    """Score networks by playing one episode each, ``batch_size`` at a time.

    The environment is stepped as a batch (one environment per network);
    the networks' forward passes are run one after another in Python.

    Args:
        networks: Networks to evaluate; each must provide ``forward``.
        env: Vectorized task. ``reset`` is given one PRNG key per
            environment and ``step`` one action row per environment, which
            is the EvoJAX vectorized-task convention.
        batch_size: Number of environments stepped together.

    Returns:
        One episode return per network, in the order of ``networks``. Rewards
        received after a network's own episode has finished are not counted.
    """
    fitness_scores: List[float] = []

    for start in range(0, len(networks), batch_size):
        batch = networks[start:start + batch_size]
        n_envs = len(batch)

        seed = int(time.time() * 1000) % (2**32 - 1)
        # One key per environment so the reset is batched like the step.
        reset_keys = random.split(random.PRNGKey(seed), n_envs)
        states = env.reset(reset_keys)

        total_rewards = np.zeros(n_envs)
        active = np.ones(n_envs, dtype=bool)

        for _ in range(1000):
            observations = states.obs / 10.0

            # forward() returns (1, 3) per network; flatten each to (3,) so
            # the stacked batch is (n_envs, 3) rather than (n_envs, 1, 3).
            actions = np.stack([
                np.asarray(net.forward(obs[None, :])).reshape(-1)
                for net, obs in zip(batch, observations)
            ])
            binary_actions = (actions > 0.5).astype(np.float32)

            states, rewards, dones = env.step(states, binary_actions)

            # Count a reward only while that environment's episode is running.
            step_rewards = np.asarray(rewards, dtype=np.float64).reshape(n_envs)
            total_rewards += np.where(active, step_rewards, 0.0)
            active &= ~np.asarray(dones, dtype=bool).reshape(n_envs)
            if not active.any():
                break

        fitness_scores.extend(float(score) for score in total_rewards)

    return fitness_scores
|
|
|
def create_next_generation(population: List["Network"], fitness_scores: List[float], config: Dict):
    """Build the next population from the current one and its fitness scores.

    The fittest 20% (at least two) are carried over unchanged, best first.
    The rest are mutated copies of tournament winners drawn from the fittest
    50% (at least five).

    Args:
        population: Current networks, in any order.
        fitness_scores: ``fitness_scores[i]`` is the score of ``population[i]``.
        config: Mutation settings forwarded to ``Genome.mutate``.

    Returns:
        A list with the same length as ``population``.
    """
    size = len(population)
    if size == 0:
        return []

    # Rank by fitness first: the population arrives in arbitrary order, so
    # slicing it directly would pick elites and parents by position.
    ranking = sorted(range(size), key=lambda i: fitness_scores[i], reverse=True)

    n_elite = min(size, max(2, int(0.2 * size)))
    next_population = [population[i] for i in ranking[:n_elite]]

    n_top = min(size, max(5, int(0.5 * size)))
    parent_pool = ranking[:n_top]
    # Sampling without replacement cannot draw more than the pool holds.
    tournament_size = min(3, len(parent_pool))

    while len(next_population) < size:
        contenders = np.random.choice(parent_pool, tournament_size, replace=False)
        winner = max((int(i) for i in contenders), key=lambda i: fitness_scores[i])

        # Mutate a private copy: sharing the parent's genome would also
        # alter the parent, the elites and every sibling built from it.
        child_genome = copy.deepcopy(population[winner].genome)
        child_genome.mutate(config)
        # Wrap after mutating so the network sees any newly added nodes.
        next_population.append(Network(child_genome))

    return next_population
|
|
|
def record_gameplay(network: Network, env: SlimeVolley, filename: str = 'gameplay.gif', max_steps: int = 1000):
    """Play one game with ``network`` and save the rendered frames as a GIF.

    Args:
        network: Policy to record; must provide ``forward``.
        env: Vectorized task; it is driven here as a batch of one, following
            the EvoJAX vectorized-task convention.
        filename: Output path of the animated GIF.
        max_steps: Upper bound on the number of recorded steps.
    """
    frames = []

    seed = int(time.time() * 1000) % (2**32 - 1)
    # A batch holding a single key -> a batch holding a single environment.
    state = env.reset(random.split(random.PRNGKey(seed), 1))
    done = False
    steps = 0
    thresholds = jnp.array([0.5, 0.5, 0.5])

    while not done and steps < max_steps:
        frames.append(env.render(state))

        # Accept both (obs_dim,) and (1, obs_dim) observations.
        obs = jnp.reshape(state.obs, (1, -1)) / 10.0
        raw_action = network.forward(obs)
        binary_action = (raw_action > thresholds).astype(jnp.float32)

        # When the first two actions are both on, keep only the one with the
        # larger raw output.
        # NOTE(review): evaluate_parallel does not apply this tie-break, so a
        # recording can differ from the behaviour that was scored - confirm
        # which one is intended.
        both_active = jnp.logical_and(binary_action[:, 0] > 0, binary_action[:, 1] > 0)
        prefer_first = raw_action[:, 0] > raw_action[:, 1]
        binary_action = binary_action.at[:, 0].set(
            jnp.where(both_active, prefer_first.astype(jnp.float32), binary_action[:, 0])
        )
        binary_action = binary_action.at[:, 1].set(
            jnp.where(both_active, (~prefer_first).astype(jnp.float32), binary_action[:, 1])
        )

        state, reward, dones = env.step(state, binary_action)
        # step() reports done per environment; collapse to a plain bool.
        done = bool(np.all(np.asarray(dones)))
        steps += 1

    if frames:
        frames[0].save(
            filename,
            save_all=True,
            append_images=frames[1:],
            duration=50,
            loop=0,
        )
        print(f"Gameplay recorded and saved to {filename}")
    else:
        print("No frames were recorded")
|
|
|
def main():
    """Evolve a population on SlimeVolley and record the best network.

    Runs up to 1000 generations. Stops early once the best fitness exceeds
    8.0. Records a GIF whenever the best fitness has risen by more than 2.0
    since the last recording (or since the first generation), and once more
    at the end.
    """
    env = SlimeVolley(max_steps=1000)

    config = {
        'population_size': 64,
        'batch_size': 8,
        'weight_mutation_rate': 0.95,
        'weight_mutation_power': 4.0,
        'add_node_rate': 0.0,
        'add_connection_rate': 0.0,
    }

    print("\nTraining Configuration:")
    print(f"Population Size: {config['population_size']}")
    print(f"Batch Size: {config['batch_size']}")
    print(f"Mutation Rate: {config['weight_mutation_rate']}")
    print("-" * 40)

    population = [
        Network(Genome(n_inputs=12, n_outputs=3))
        for _ in range(config['population_size'])
    ]

    best_fitness = float('-inf')
    best_network = None
    # Best fitness at the time of the last recording; None until the first
    # generation has set a baseline.
    last_recorded_fitness = None

    for generation in range(1000):
        start_time = time.time()
        print(f"\nGeneration {generation}")

        fitness_scores = evaluate_parallel(
            population,
            env,
            batch_size=config['batch_size'],
        )

        max_fitness = max(fitness_scores)
        if max_fitness > best_fitness:
            best_idx = fitness_scores.index(max_fitness)
            best_fitness = max_fitness
            # Snapshot the champion so later mutation of the population
            # cannot change it.
            best_network = copy.deepcopy(population[best_idx])
            print(f"New best fitness: {best_fitness:.2f}")

            # The improvement test has to compare against an older value;
            # checking it against the just-updated best can never succeed.
            if last_recorded_fitness is None:
                last_recorded_fitness = best_fitness
            elif best_fitness > last_recorded_fitness + 2.0:
                record_gameplay(best_network, env, f"best_gen_{generation}.gif")
                last_recorded_fitness = best_fitness

        # NOTE(review): confirm that an episode return above 8.0 is
        # attainable in this task; otherwise this exit never triggers.
        if best_fitness > 8.0:
            print(f"Target fitness reached: {best_fitness:.2f}")
            break

        population = create_next_generation(
            population,
            fitness_scores,
            config,
        )

        if generation % 5 == 0:
            gen_time = time.time() - start_time
            print(f"\nGeneration {generation} Stats:")
            print(f"Best Fitness: {max_fitness:.2f}")
            print(f"Average Fitness: {np.mean(fitness_scores):.2f}")
            print(f"Generation Time: {gen_time:.2f}s")

    print("\nTraining complete!")
    print(f"Best fitness achieved: {best_fitness:.2f}")

    if best_network is not None:
        record_gameplay(best_network, env, "final_gameplay.gif")
|
|
|
# Run training only when the file is executed as a script, not on import.
if "__main__" == __name__:
    main()