# neat / train.py
# Author: eyad-silx
# Uploaded with huggingface_hub, commit ecccd48 (verified).
"""Train NEAT networks to play volleyball using hardware acceleration when available."""
import copy
import heapq
import io
import os
import time
import zlib
from typing import Dict, List, Optional, Tuple

import jax
import jax.numpy as jnp
import numpy as np
from evojax.task.slimevolley import SlimeVolley
from jax import random
from PIL import Image
# Try to initialize JAX with GPU.
# NOTE: these XLA/CUDA settings only take effect if they are in place before
# the JAX backend initialises; that presumably happens at the first
# jax.devices() call below rather than at `import jax`.
try:
    # setdefault: respect a device / allocator choice the user has already
    # exported (e.g. CUDA_VISIBLE_DEVICES=2) instead of silently overriding it.
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')
    os.environ.setdefault('XLA_PYTHON_CLIENT_PREALLOCATE', 'false')
    os.environ.setdefault('XLA_PYTHON_CLIENT_ALLOCATOR', 'platform')
    # Query the devices once and report what JAX will use.
    devices = jax.devices()
    print(f"JAX devices available: {devices}")
    print(f"Using device: {devices[0].platform.upper()}")
except Exception as e:
    print(f"Note: Using CPU - {e}")
class NodeGene:
    """A gene representing a node in the neural network.

    Attributes:
        id: Node identifier, unique within a genome.
        type: 'input', 'hidden', or 'output'.
        activation: Activation name. It is stored only; Network.forward in
            this file always applies tanh.
        bias: Initial bias, a draw from N(0, 0.1**2) seeded by the node id.
            Every node with the same id therefore starts with the same bias.
    """

    def __init__(self, id: int, node_type: str, activation: str = 'tanh'):
        self.id = id
        self.type = node_type  # 'input', 'hidden', or 'output'
        self.activation = activation
        # Seed from a stable checksum of the id. The built-in hash() of a str
        # is salted per interpreter process (PYTHONHASHSEED), so it would give
        # different "deterministic" biases on every run.
        seed = zlib.crc32(f"node_{id}".encode("utf-8")) % (2**32 - 1)
        key = random.PRNGKey(seed)
        self.bias = float(random.normal(key, shape=()) * 0.1)
class ConnectionGene:
    """A gene representing a weighted connection between two nodes.

    Attributes:
        source: Id of the node the signal comes from.
        target: Id of the node the signal goes to.
        enabled: Disabled connections are skipped by Network.forward.
        innovation: ``hash((source, target))``. It is the same for every gene
            that links the same pair of nodes.
        weight: Connection weight.
    """

    def __init__(self, source: int, target: int, weight: Optional[float] = None, enabled: bool = True):
        self.source = source
        self.target = target
        self.enabled = enabled
        # Hashes of int tuples are not salted, so this is stable across runs.
        self.innovation = hash((source, target))
        if weight is None:
            # Default weight ~ N(0, 0.1**2), seeded by a stable checksum of the
            # endpoints. A str hash() is salted per process, which would make
            # the "deterministic" default weight differ between runs.
            seed = zlib.crc32(f"conn_{source}_{target}".encode("utf-8")) % (2**32 - 1)
            key = random.PRNGKey(seed)
            weight = float(random.normal(key, shape=()) * 0.1)
        self.weight = weight
class Genome:
    """NEAT genome: node genes keyed by id plus a list of connection genes.

    Node ids are laid out as ``0 .. n_inputs-1`` for inputs, the next three
    ids for the outputs (left, right, jump), and higher ids for hidden nodes.
    """

    @staticmethod
    def _fresh_key() -> jnp.ndarray:
        """Return a new JAX PRNG key seeded from NumPy's global generator.

        Seeding from the wall clock in milliseconds would give identical keys
        (and therefore identical genomes or mutations) to calls made within
        the same millisecond, e.g. when a whole population is built in a
        loop. Drawing from ``np.random`` gives a distinct key per call, and
        ``np.random.seed`` makes the run reproducible.
        """
        return random.PRNGKey(int(np.random.randint(0, 2**31 - 1)))

    def __init__(self, n_inputs: int, n_outputs: int):
        """Create a genome with random direct and hidden-layer connections.

        Args:
            n_inputs: Number of input nodes (observation size).
            n_outputs: Ignored. The genome always has exactly 3 outputs
                because Network requires that count. The argument is kept
                for interface compatibility.
        """
        # Input nodes occupy ids 0 .. n_inputs-1.
        self.node_genes = {i: NodeGene(i, 'input') for i in range(n_inputs)}
        n_outputs = 3  # Force exactly 3 outputs: left, right, jump.
        for i in range(n_outputs):
            self.node_genes[n_inputs + i] = NodeGene(n_inputs + i, 'output')
        self.connection_genes: List[ConnectionGene] = []

        master_key = self._fresh_key()

        # Direct input -> output connections, each present with probability 0.7.
        for i in range(n_inputs):
            for j in range(n_outputs):
                master_key, key = random.split(master_key)
                if random.uniform(key, shape=()) < 0.7:
                    master_key, key = random.split(master_key)
                    weight = float(random.normal(key, shape=()) * 0.5)  # Larger initial weights
                    self.connection_genes.append(ConnectionGene(i, n_inputs + j, weight=weight))

        # 1-3 hidden nodes. Each gets input->hidden and hidden->output links
        # with probability 0.5 per link.
        master_key, key = random.split(master_key)
        n_hidden = int(random.randint(key, (), 1, 4))
        hidden_start = n_inputs + n_outputs
        for h in range(n_hidden):
            node_id = hidden_start + h
            self.node_genes[node_id] = NodeGene(node_id, 'hidden')
            for j in range(n_inputs):
                master_key, key = random.split(master_key)
                if random.uniform(key, shape=()) < 0.5:
                    master_key, key = random.split(master_key)
                    weight = float(random.normal(key, shape=()) * 0.5)
                    self.connection_genes.append(ConnectionGene(j, node_id, weight=weight))
            for j in range(n_outputs):
                master_key, key = random.split(master_key)
                if random.uniform(key, shape=()) < 0.5:
                    master_key, key = random.split(master_key)
                    weight = float(random.normal(key, shape=()) * 0.5)
                    self.connection_genes.append(ConnectionGene(node_id, n_inputs + j, weight=weight))

    def mutate(self, config: Dict):
        """Mutate this genome in place.

        Args:
            config: Must provide 'weight_mutation_rate',
                'weight_mutation_power', 'add_node_rate' and
                'add_connection_rate'.
        """
        key = self._fresh_key()

        # Mutate connection weights.
        for conn in self.connection_genes:
            key, subkey = random.split(key)
            if random.uniform(subkey, shape=()) < config['weight_mutation_rate']:
                key, subkey = random.split(key)
                if random.uniform(subkey, shape=()) < 0.1:
                    # Sometimes reset the weight completely.
                    key, subkey = random.split(key)
                    conn.weight = float(random.normal(subkey, shape=()) * 0.5)
                else:
                    # Otherwise perturb the existing weight.
                    key, subkey = random.split(key)
                    conn.weight += float(random.normal(subkey) * config['weight_mutation_power'])

        # Mutate node biases (10% chance per node).
        for node in self.node_genes.values():
            key, subkey = random.split(key)
            if random.uniform(subkey, shape=()) < 0.1:
                key, subkey = random.split(key)
                node.bias += float(random.normal(subkey) * 0.1)

        # Add a node by splitting an existing connection.
        key, subkey = random.split(key)
        if random.uniform(subkey, shape=()) < config['add_node_rate']:
            # Only split connections that currently carry signal. Splitting a
            # disabled one would re-open that path through the new node.
            enabled = [c for c in self.connection_genes if c.enabled]
            if enabled:
                conn = enabled[np.random.randint(len(enabled))]
                new_id = max(self.node_genes.keys()) + 1
                self.node_genes[new_id] = NodeGene(new_id, 'hidden')
                key, subkey = random.split(key)
                weight1 = float(random.normal(subkey, shape=()) * 0.5)
                key, subkey = random.split(key)
                weight2 = float(random.normal(subkey, shape=()) * 0.5)
                self.connection_genes.append(ConnectionGene(conn.source, new_id, weight=weight1))
                self.connection_genes.append(ConnectionGene(new_id, conn.target, weight=weight2))
                conn.enabled = False

        # Add a new connection.
        key, subkey = random.split(key)
        if random.uniform(subkey, shape=()) < config['add_connection_rate']:
            sources = list(self.node_genes.keys())
            # Input node values are overwritten with the observation in
            # Network.forward, so a connection into one would have no effect.
            targets = [nid for nid, node in self.node_genes.items() if node.type != 'input']
            existing = {(c.source, c.target) for c in self.connection_genes}
            if targets:
                for _ in range(10):  # Try 10 times to find a valid connection.
                    source = int(np.random.choice(sources))
                    target = int(np.random.choice(targets))
                    # Only allow lower -> higher id. NOTE: this alone does not
                    # guarantee acyclicity, because the initial hidden->output
                    # links run from higher to lower ids.
                    if source < target and (source, target) not in existing:
                        key, subkey = random.split(key)
                        weight = float(random.normal(subkey, shape=()) * 0.5)
                        self.connection_genes.append(ConnectionGene(source, target, weight=weight))
                        break
class Network:
    """Evaluate a Genome as a network of tanh units.

    Node lists are read from ``self.genome`` on every access instead of being
    cached at construction. Nodes that ``genome.mutate`` adds after the
    Network is built are therefore still evaluated.
    """

    def __init__(self, genome: "Genome"):
        self.genome = genome
        # Verify we have exactly 3 output nodes (left, right, jump).
        assert len(self.output_nodes) == 3, f"Expected 3 output nodes, got {len(self.output_nodes)}"

    def _sorted_nodes(self, node_type: str) -> list:
        """Return the genome's nodes of ``node_type`` sorted by id."""
        return sorted(
            (node for node in self.genome.node_genes.values() if node.type == node_type),
            key=lambda node: node.id,
        )

    @property
    def input_nodes(self) -> list:
        """Input nodes sorted by id. Column i of the input feeds the i-th one."""
        return self._sorted_nodes('input')

    @property
    def hidden_nodes(self) -> list:
        """Hidden nodes sorted by id."""
        return self._sorted_nodes('hidden')

    @property
    def output_nodes(self) -> list:
        """Output nodes sorted by id, in the column order of ``forward``'s result."""
        return self._sorted_nodes('output')

    def _evaluation_order(self, incoming: Dict[int, list]) -> List[int]:
        """Order non-input node ids so each node's sources are computed first.

        The order is a topological sort (Kahn's algorithm) over enabled
        connections between non-input nodes. Ties are broken by the default
        order (hidden nodes by id, then outputs by id), so genomes whose
        edges already respect that order are evaluated exactly as before.
        Nodes on a cycle cannot be ordered; they are appended last in
        default order and read whatever value their sources hold at that time.

        Args:
            incoming: Maps each node id to its enabled incoming connections.
        """
        base = [node.id for node in self.hidden_nodes + self.output_nodes]
        rank = {nid: pos for pos, nid in enumerate(base)}
        indegree = dict.fromkeys(base, 0)
        dependents: Dict[int, list] = {nid: [] for nid in base}
        for target in base:
            for conn in incoming[target]:
                if conn.source in rank:  # inputs impose no ordering constraint
                    indegree[target] += 1
                    dependents[conn.source].append(target)

        ready = [(rank[nid], nid) for nid in base if indegree[nid] == 0]
        heapq.heapify(ready)
        order: List[int] = []
        while ready:
            _, nid = heapq.heappop(ready)
            order.append(nid)
            for dep in dependents[nid]:
                indegree[dep] -= 1
                if indegree[dep] == 0:
                    heapq.heappush(ready, (rank[dep], dep))

        placed = set(order)
        order.extend(nid for nid in base if nid not in placed)
        return order

    def forward(self, x: jnp.ndarray) -> jnp.ndarray:
        """Run the network on one observation or a batch of observations.

        Args:
            x: Array of shape (input_dim,) or (batch_size, input_dim). Column i
                feeds the i-th input node in id order.

        Returns:
            Array of shape (batch_size, 3) with the tanh activations of the
            output nodes in id order.
        """
        if x.ndim == 1:
            x = jnp.expand_dims(x, 0)
        batch_size = x.shape[0]
        genome = self.genome

        # Every node starts at its bias, so a node that is read before it is
        # computed (only possible on a cycle) still has a defined value.
        values = {
            node.id: jnp.zeros((batch_size,)) + node.bias
            for node in genome.node_genes.values()
        }
        for i, node in enumerate(self.input_nodes):
            values[node.id] = x[:, i]

        # Group enabled connections by target once, instead of rescanning the
        # whole connection list for every node.
        incoming: Dict[int, list] = {nid: [] for nid in genome.node_genes}
        for conn in genome.connection_genes:
            if conn.enabled and conn.target in incoming:
                incoming[conn.target].append(conn)

        for nid in self._evaluation_order(incoming):
            total = jnp.zeros((batch_size,)) + genome.node_genes[nid].bias
            for conn in incoming[nid]:
                total = total + values[conn.source] * conn.weight
            values[nid] = jnp.tanh(total)

        # Stack along a new last axis to get (batch_size, 3).
        return jnp.stack([values[node.id] for node in self.output_nodes], axis=-1)
def evaluate_parallel(networks: List[Network], env: SlimeVolley, batch_size: int = 8,
                      max_steps: int = 1000) -> List[float]:
    """Evaluate networks in batches, each batch against its own environments.

    ``env.reset``/``env.step`` are used as vectorized over a leading batch
    axis: network ``j`` of a batch controls environment instance ``j``.
    Observations, rewards and done flags are indexed per network below.

    Args:
        networks: Networks to evaluate.
        env: SlimeVolley task.
        batch_size: Number of networks (and environment instances) per batch.
        max_steps: Maximum number of environment steps per episode.

    Returns:
        Total episode reward of each network, in the order of ``networks``.
    """
    fitness_scores: List[float] = []
    thresholds = np.array([0.5, 0.5, 0.5])
    for start in range(0, len(networks), batch_size):
        batch = networks[start:start + batch_size]
        n_envs = len(batch)

        # One PRNG key per environment instance: reset takes a batch of keys
        # with shape (n_envs, 2), not a single (2,) key.
        seed = int(np.random.randint(0, 2**31 - 1))
        states = env.reset(random.split(random.PRNGKey(seed), n_envs))

        total_rewards = np.zeros(n_envs)
        # An instance stops contributing reward once it has reported done.
        active = np.ones(n_envs, dtype=bool)
        for _ in range(max_steps):
            observations = states.obs / 10.0
            # forward() returns shape (1, 3) for a single observation. Take
            # row 0 so the stacked actions have shape (n_envs, 3).
            actions = np.stack([
                np.asarray(net.forward(obs))[0]
                for net, obs in zip(batch, observations)
            ])
            binary_actions = (actions > thresholds).astype(np.float32)

            states, rewards, dones = env.step(states, binary_actions)
            total_rewards += np.asarray(rewards, dtype=np.float64) * active
            active &= ~np.asarray(dones, dtype=bool)
            if not active.any():
                break

        fitness_scores.extend(float(r) for r in total_rewards)
    return fitness_scores
def create_next_generation(population: List[Network], fitness_scores: List[float], config: Dict):
    """Create the next generation through elitism and tournament selection.

    ``fitness_scores[i]`` must be the fitness of ``population[i]``. The best
    20% of networks (at least 2) are carried over unchanged. The remaining
    slots are filled with mutated deep copies of tournament winners drawn
    from the top 50% (at least 5) of the population.

    Args:
        population: Current networks, in any order.
        fitness_scores: Fitness of each network, aligned with ``population``.
        config: Mutation settings forwarded to ``Genome.mutate``.

    Returns:
        A new list with the same length as ``population``.
    """
    if not population:
        return []
    size = len(population)
    # Indices ordered best-first. The incoming population is not sorted, so
    # elites and the selection pool must be chosen by fitness, not position.
    ranked = sorted(range(size), key=lambda i: fitness_scores[i], reverse=True)

    # Keep the top 20% unchanged (less elitism = faster adaptation).
    n_elite = min(size, max(2, int(0.2 * size)))
    next_population = [population[i] for i in ranked[:n_elite]]

    # Fill the rest with mutated copies of tournament winners from the top 50%.
    n_top = min(size, max(5, int(0.5 * size)))
    tournament_size = min(3, n_top)  # replace=False needs a pool this large
    while len(next_population) < size:
        picks = np.random.choice(n_top, tournament_size, replace=False)
        # `ranked` is best-first, so the smallest picked position wins.
        parent = population[ranked[int(picks.min())]]
        # Deep-copy the genome. Mutating a shared genome would also change the
        # parent, including elites that are meant to survive unchanged.
        child = Network(copy.deepcopy(parent.genome))
        child.genome.mutate(config)
        next_population.append(child)
    return next_population
def record_gameplay(network: Network, env: SlimeVolley, filename: str = 'gameplay.gif', max_steps: int = 1000):
    """Record a game played by the network and save it as a GIF.

    Args:
        network: Policy network. It receives observations scaled by 1/10, as
            in evaluation.
        env: SlimeVolley task, reset here with a batch of one instance.
        filename: Output GIF path.
        max_steps: Maximum number of steps (and therefore frames) to record.
    """
    frames = []
    # reset is used with a batch of keys, so pass one key with shape (1, 2).
    seed = int(np.random.randint(0, 2**31 - 1))
    state = env.reset(random.split(random.PRNGKey(seed), 1))
    thresholds = jnp.array([0.5, 0.5, 0.5])
    done = False
    steps = 0
    while not done and steps < max_steps:
        frames.append(env.render(state))  # render returns a PIL Image

        # Presumably (1, obs_dim) after the batched reset. forward() also
        # accepts a 1-D observation, so no extra axis is added here.
        obs = state.obs / 10.0
        raw_action = network.forward(obs)
        binary_action = (raw_action > thresholds).astype(jnp.float32)

        # Prevent simultaneous left/right by keeping the larger raw output.
        # NOTE: evaluate_parallel does not apply this rule, so the recorded
        # play can differ slightly from the behaviour that was scored.
        both_active = jnp.logical_and(binary_action[:, 0] > 0, binary_action[:, 1] > 0)
        prefer_left = raw_action[:, 0] > raw_action[:, 1]
        binary_action = binary_action.at[:, 0].set(
            jnp.where(both_active, prefer_left.astype(jnp.float32), binary_action[:, 0])
        )
        binary_action = binary_action.at[:, 1].set(
            jnp.where(both_active, (~prefer_left).astype(jnp.float32), binary_action[:, 1])
        )

        state, _, done_flags = env.step(state, binary_action)
        done = bool(jnp.all(done_flags))
        steps += 1

    if frames:
        frames[0].save(
            filename,
            save_all=True,
            append_images=frames[1:],
            duration=50,  # 50 ms per frame = 20 fps
            loop=0,
        )
        print(f"Gameplay recorded and saved to {filename}")
    else:
        print("No frames were recorded")
def main():
    """Main training loop with hardware acceleration when available.

    Evolves a population of NEAT networks on SlimeVolley. The loop keeps a
    snapshot of the best network seen so far and records a GIF when the best
    fitness improves substantially. It stops early once the best fitness
    exceeds ``config['target_fitness']`` and finally records the best network.
    """
    env = SlimeVolley(max_steps=1000)

    config = {
        'population_size': 64,
        'batch_size': 8,  # Smaller batch size for better compatibility
        'weight_mutation_rate': 0.95,
        'weight_mutation_power': 4.0,
        'add_node_rate': 0.0,
        'add_connection_rate': 0.0,
        # Stop once the best fitness exceeds this value.
        # NOTE(review): SlimeVolley returns are presumably bounded by the
        # number of lives per side, which may make 8.0 unreachable. Confirm
        # against evojax and lower it if so.
        'target_fitness': 8.0,
        # Record a GIF when the best fitness jumps by more than this amount.
        'record_improvement': 2.0,
    }
    print("\nTraining Configuration:")
    print(f"Population Size: {config['population_size']}")
    print(f"Batch Size: {config['batch_size']}")
    print(f"Mutation Rate: {config['weight_mutation_rate']}")
    print("-" * 40)

    population = [
        Network(Genome(n_inputs=12, n_outputs=3))
        for _ in range(config['population_size'])
    ]
    best_fitness = float('-inf')
    best_network = None

    for generation in range(1000):
        start_time = time.time()
        print(f"\nGeneration {generation}")

        fitness_scores = evaluate_parallel(
            population,
            env,
            batch_size=config['batch_size'],
        )

        max_fitness = max(fitness_scores)
        if max_fitness > best_fitness:
            previous_best = best_fitness
            best_idx = fitness_scores.index(max_fitness)
            best_fitness = max_fitness
            # Snapshot the genome so later in-place mutations cannot alter
            # the recorded best network.
            best_network = Network(copy.deepcopy(population[best_idx].genome))
            print(f"New best fitness: {best_fitness:.2f}")
            # Compare against the previous best. Comparing against the
            # just-updated best_fitness could never be true. The first finite
            # score is not treated as an "improvement".
            if (previous_best > float('-inf')
                    and best_fitness - previous_best > config['record_improvement']):
                record_gameplay(best_network, env, f"best_gen_{generation}.gif")

        # Early stopping
        if best_fitness > config['target_fitness']:
            print(f"Target fitness reached: {best_fitness:.2f}")
            break

        population = create_next_generation(population, fitness_scores, config)

        # Print stats every 5 generations.
        if generation % 5 == 0:
            gen_time = time.time() - start_time
            print(f"\nGeneration {generation} Stats:")
            print(f"Best Fitness: {max_fitness:.2f}")
            print(f"Average Fitness: {np.mean(fitness_scores):.2f}")
            print(f"Generation Time: {gen_time:.2f}s")

    print("\nTraining complete!")
    print(f"Best fitness achieved: {best_fitness:.2f}")
    # Save final network
    if best_network is not None:
        record_gameplay(best_network, env, "final_gameplay.gif")
if __name__ == '__main__':
    # Start training only when executed as a script, not when imported.
    main()