"""Train NEAT networks to play volleyball using hardware acceleration when available."""
import os

# XLA and CUDA read these variables when the JAX backend is first initialised,
# so they are set before jax (or evojax, which imports jax) is imported.
# setdefault leaves any value the user already exported untouched.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')
os.environ.setdefault('XLA_PYTHON_CLIENT_PREALLOCATE', 'false')
os.environ.setdefault('XLA_PYTHON_CLIENT_ALLOCATOR', 'platform')

# The imports below intentionally follow the environment setup above.
import copy
import io
import time
import zlib
from typing import Dict, List, Optional, Tuple

import jax
import jax.numpy as jnp
import numpy as np
from evojax.task.slimevolley import SlimeVolley
from jax import random
from PIL import Image

# Seeds are kept inside the signed 32-bit range so they are valid for
# np.random.randint on every platform and for random.PRNGKey.
_SEED_MODULUS = 2**31 - 1

# Activation functions selectable through NodeGene.activation.
_ACTIVATIONS = {
    'tanh': jnp.tanh,
    'sigmoid': jax.nn.sigmoid,
    'relu': jax.nn.relu,
    'identity': lambda v: v,
}


def _report_devices() -> None:
    """Print the devices JAX can see and the platform it will use first."""
    try:
        devices = jax.devices()
        print(f"JAX devices available: {devices}")
        print(f"Using device: {devices[0].platform.upper()}")
    except Exception as e:  # diagnostics only; never abort training over this
        print(f"Note: Using CPU - {str(e)}")


def _new_key() -> jnp.ndarray:
    """Return a fresh JAX PRNG key.

    Seeds come from NumPy's global generator rather than the wall clock:
    millisecond timestamps repeat for calls made within the same millisecond,
    which handed identical "random" streams to genomes built or mutated
    back-to-back. Calling ``np.random.seed`` once makes a run reproducible.
    """
    return random.PRNGKey(int(np.random.randint(0, _SEED_MODULUS)))


def _stable_seed(label: str) -> int:
    """Map ``label`` to a seed that is identical in every interpreter run.

    The built-in ``hash`` of a ``str`` is salted per process (PYTHONHASHSEED),
    so it cannot provide the deterministic initialisation the genes ask for.
    """
    return zlib.crc32(label.encode('utf-8')) % _SEED_MODULUS


class NodeGene:
    """A gene representing a node in the neural network."""

    def __init__(self, id: int, node_type: str, activation: str = 'tanh'):
        self.id = id
        self.type = node_type  # 'input', 'hidden', or 'output'
        self.activation = activation  # key into _ACTIVATIONS
        # Deterministic initial bias: a given node id always starts with the same
        # bias, in every run (crc32-based seed instead of salted str hashing).
        key = random.PRNGKey(_stable_seed(f"node_{id}"))
        self.bias = float(random.normal(key, shape=()) * 0.1)


class ConnectionGene:
    """A gene representing a connection between nodes."""

    def __init__(self, source: int, target: int, weight: Optional[float] = None,
                 enabled: bool = True):
        self.source = source
        self.target = target
        self.enabled = enabled
        # Hashes of int tuples are not salted, so this is stable across runs.
        self.innovation = hash((source, target))
        if weight is None:
            # Deterministic default weight for a given (source, target) pair.
            key = random.PRNGKey(_stable_seed(f"conn_{source}_{target}"))
            weight = float(random.normal(key, shape=()) * 0.1)
        self.weight = weight


class Genome:
    """NEAT genome: node genes keyed by id plus a list of connection genes."""

    def __init__(self, n_inputs: int, n_outputs: int):
        """Build a randomly wired genome.

        Args:
            n_inputs: number of input nodes (ids ``0 .. n_inputs - 1``).
            n_outputs: accepted for API compatibility but ignored; exactly three
                output nodes (left, right, jump) are always created because
                ``Network`` requires three outputs.
        """
        # Create input nodes (0 to n_inputs-1)
        self.node_genes: Dict[int, NodeGene] = {
            i: NodeGene(i, 'input') for i in range(n_inputs)
        }
        n_outputs = 3  # Force exactly 3 outputs for left, right, jump
        for i in range(n_outputs):
            self.node_genes[n_inputs + i] = NodeGene(n_inputs + i, 'output')

        self.connection_genes: List[ConnectionGene] = []
        master_key = _new_key()

        def maybe_connect(source: int, target: int, probability: float) -> None:
            """With ``probability``, add source->target with a weight of std 0.5."""
            nonlocal master_key
            master_key, key = random.split(master_key)
            if random.uniform(key, shape=()) < probability:
                master_key, key = random.split(master_key)
                weight = float(random.normal(key, shape=()) * 0.5)
                self.connection_genes.append(
                    ConnectionGene(source, target, weight=weight)
                )

        # Direct input -> output connections (70% chance each).
        for i in range(n_inputs):
            for j in range(n_outputs):
                maybe_connect(i, n_inputs + j, 0.7)

        # 1-3 hidden nodes (randint's maxval is exclusive), each wired to
        # random inputs and random outputs with a 50% chance per edge.
        master_key, key = random.split(master_key)
        n_hidden = int(random.randint(key, (), 1, 4))
        hidden_start = n_inputs + n_outputs
        for i in range(n_hidden):
            node_id = hidden_start + i
            self.node_genes[node_id] = NodeGene(node_id, 'hidden')
            for j in range(n_inputs):
                maybe_connect(j, node_id, 0.5)
            for j in range(n_outputs):
                maybe_connect(node_id, n_inputs + j, 0.5)

    def _creates_cycle(self, source: int, target: int) -> bool:
        """Return True if adding source->target would close a loop.

        Only enabled connections are considered, since disabled ones take no
        part in ``Network.forward``.
        """
        if source == target:
            return True
        successors: Dict[int, List[int]] = {}
        for conn in self.connection_genes:
            if conn.enabled:
                successors.setdefault(conn.source, []).append(conn.target)
        # A cycle appears iff `source` is already reachable from `target`.
        stack, seen = [target], set()
        while stack:
            node_id = stack.pop()
            if node_id == source:
                return True
            if node_id in seen:
                continue
            seen.add(node_id)
            stack.extend(successors.get(node_id, ()))
        return False

    def mutate(self, config: Dict):
        """Mutate this genome in place.

        Reads ``config['weight_mutation_rate']``, ``config['weight_mutation_power']``,
        ``config['add_node_rate']`` and ``config['add_connection_rate']``.
        """
        key = _new_key()

        # Mutate connection weights
        for conn in self.connection_genes:
            key, subkey = random.split(key)
            if random.uniform(subkey, shape=()) < config['weight_mutation_rate']:
                key, subkey = random.split(key)
                if random.uniform(subkey, shape=()) < 0.1:
                    # Sometimes reset the weight completely
                    key, subkey = random.split(key)
                    conn.weight = float(random.normal(subkey, shape=()) * 0.5)
                else:
                    # Otherwise perturb the existing weight
                    key, subkey = random.split(key)
                    conn.weight += float(
                        random.normal(subkey) * config['weight_mutation_power']
                    )

        # Mutate node biases (10% chance per node)
        for node in self.node_genes.values():
            key, subkey = random.split(key)
            if random.uniform(subkey, shape=()) < 0.1:
                key, subkey = random.split(key)
                node.bias += float(random.normal(subkey) * 0.1)

        # Add a node by splitting an existing connection
        key, subkey = random.split(key)
        if random.uniform(subkey, shape=()) < config['add_node_rate']:
            # Only split active connections; splitting a disabled one would
            # silently resurrect a pathway that evolution had switched off.
            enabled = [c for c in self.connection_genes if c.enabled]
            if enabled:
                conn = enabled[np.random.randint(len(enabled))]
                new_id = max(self.node_genes) + 1
                self.node_genes[new_id] = NodeGene(new_id, 'hidden')
                key, subkey = random.split(key)
                weight1 = float(random.normal(subkey, shape=()) * 0.5)
                key, subkey = random.split(key)
                weight2 = float(random.normal(subkey, shape=()) * 0.5)
                self.connection_genes.append(
                    ConnectionGene(conn.source, new_id, weight=weight1)
                )
                self.connection_genes.append(
                    ConnectionGene(new_id, conn.target, weight=weight2)
                )
                conn.enabled = False  # the split path replaces the old edge

        # Add a new connection
        key, subkey = random.split(key)
        if random.uniform(subkey, shape=()) < config['add_connection_rate']:
            nodes = list(self.node_genes)
            for _ in range(10):  # Try 10 times to find a valid connection
                source = int(np.random.choice(nodes))
                target = int(np.random.choice(nodes))
                # Input values are overwritten by observations, so edges into
                # them would never have an effect.
                if self.node_genes[target].type == 'input':
                    continue
                if any(c.source == source and c.target == target
                       for c in self.connection_genes):
                    continue
                # Node ids are not in topological order (hidden ids come after
                # output ids), so an id comparison cannot keep the network
                # feed-forward; an explicit cycle check does.
                if self._creates_cycle(source, target):
                    continue
                key, subkey = random.split(key)
                weight = float(random.normal(subkey, shape=()) * 0.5)
                self.connection_genes.append(
                    ConnectionGene(source, target, weight=weight)
                )
                break


class Network:
    """Feed-forward evaluator for a Genome."""

    def __init__(self, genome: Genome):
        self.genome = genome
        # Sort nodes by ID to ensure consistent ordering
        self.input_nodes = sorted(
            [n for n in genome.node_genes.values() if n.type == 'input'],
            key=lambda x: x.id,
        )
        self.hidden_nodes = sorted(
            [n for n in genome.node_genes.values() if n.type == 'hidden'],
            key=lambda x: x.id,
        )
        self.output_nodes = sorted(
            [n for n in genome.node_genes.values() if n.type == 'output'],
            key=lambda x: x.id,
        )
        # Verify we have exactly 3 output nodes
        assert len(self.output_nodes) == 3, (
            f"Expected 3 output nodes, got {len(self.output_nodes)}"
        )

    def _evaluation_order(self, incoming: Dict[int, List[ConnectionGene]]) -> List[NodeGene]:
        """Order hidden/output nodes so every source is computed before its target.

        Ties keep the hidden-then-output, id-sorted order. If a cycle is met,
        the first pending node is evaluated anyway; it then reads its
        uncomputed sources at their bias-initialised values.
        """
        pending = self.hidden_nodes + self.output_nodes
        order: List[NodeGene] = []
        while pending:
            pending_ids = {n.id for n in pending}
            ready = [n for n in pending
                     if all(c.source not in pending_ids for c in incoming[n.id])]
            if not ready:
                ready = [pending[0]]  # break a cycle deterministically
            order.extend(ready)
            ready_ids = {n.id for n in ready}
            pending = [n for n in pending if n.id not in ready_ids]
        return order

    def forward(self, x: jnp.ndarray) -> jnp.ndarray:
        """Run the network on ``x`` of shape (input_dim,) or (batch, input_dim).

        Returns an array of shape (batch, 3), one column per output node.
        """
        if x.ndim == 1:
            x = jnp.expand_dims(x, 0)
        batch_size = x.shape[0]

        # Every node starts at its bias; inputs are then overwritten.
        values = {
            node.id: jnp.zeros((batch_size,)) + node.bias
            for node in self.genome.node_genes.values()
        }
        for i, node in enumerate(self.input_nodes):
            values[node.id] = x[:, i]

        # Group enabled connections by target, preserving list order.
        incoming: Dict[int, List[ConnectionGene]] = {
            node.id: [] for node in self.hidden_nodes + self.output_nodes
        }
        for conn in self.genome.connection_genes:
            if conn.enabled and conn.target in incoming:
                incoming[conn.target].append(conn)

        for node in self._evaluation_order(incoming):
            total = jnp.zeros((batch_size,)) + node.bias
            for conn in incoming[node.id]:
                total = total + values[conn.source] * conn.weight
            values[node.id] = _ACTIVATIONS[node.activation](total)

        # Stack along a new axis to get (batch_size, 3)
        return jnp.stack([values[node.id] for node in self.output_nodes], axis=-1)


def _to_binary_actions(raw_actions: jnp.ndarray, threshold: float = 0.5) -> jnp.ndarray:
    """Threshold raw outputs of shape (batch, 3) into 0/1 actions.

    If both column 0 (left) and column 1 (right) fire, only the stronger one
    is kept. Training and recording share this so a recorded game shows the
    same policy that was evaluated.
    """
    raw_actions = jnp.asarray(raw_actions)
    binary = (raw_actions > threshold).astype(jnp.float32)
    both_active = jnp.logical_and(binary[:, 0] > 0, binary[:, 1] > 0)
    prefer_left = raw_actions[:, 0] > raw_actions[:, 1]
    binary = binary.at[:, 0].set(
        jnp.where(both_active, prefer_left.astype(jnp.float32), binary[:, 0])
    )
    binary = binary.at[:, 1].set(
        jnp.where(both_active, (~prefer_left).astype(jnp.float32), binary[:, 1])
    )
    return binary


def evaluate_parallel(networks: List[Network], env: SlimeVolley, batch_size: int = 8) -> List[float]:
    """Return the episode reward of each network, in input order.

    Up to ``batch_size`` episodes are stepped together through the vectorised
    environment; each network's forward pass runs separately per step.
    Rewards stop counting once an episode reports done.
    """
    fitness_scores: List[float] = []
    for start in range(0, len(networks), batch_size):
        batch = networks[start:start + batch_size]
        n_envs = len(batch)

        # reset/step operate on a batch, so reset needs one key per episode.
        states = env.reset(random.split(_new_key(), n_envs))
        total_rewards = np.zeros(n_envs)
        active = np.ones(n_envs, dtype=bool)

        for _ in range(1000):  # Max steps per episode
            observations = states.obs / 10.0  # (n_envs, obs_dim)
            raw = jnp.concatenate(
                [net.forward(obs) for net, obs in zip(batch, observations)], axis=0
            )  # (n_envs, 3)
            states, rewards, dones = env.step(states, _to_binary_actions(raw))
            rewards = np.asarray(rewards, dtype=np.float64).reshape(n_envs)
            # Count the reward of the step that ends an episode, none after it.
            total_rewards += np.where(active, rewards, 0.0)
            active &= ~np.asarray(dones, dtype=bool).reshape(n_envs)
            if not active.any():
                break

        fitness_scores.extend(float(r) for r in total_rewards)
    return fitness_scores


def create_next_generation(population: List[Network], fitness_scores: List[float], config: Dict):
    """Create the next generation from the current population and its fitness.

    The fittest 20% (at least 2) are carried over unchanged. The rest are
    mutated copies of tournament winners drawn from the fittest 50% (at least 5).
    Parents are never modified.
    """
    if not population:
        return []

    # Rank once, best first: elites and the tournament pool must be the
    # fittest networks, not whichever sit at the front of the list.
    order = sorted(range(len(population)), key=lambda i: fitness_scores[i], reverse=True)
    ranked = [population[i] for i in order]
    ranked_fitness = [fitness_scores[i] for i in order]

    n_elite = min(len(ranked), max(2, int(0.2 * len(ranked))))
    next_population = ranked[:n_elite]

    n_top = min(len(ranked), max(5, int(0.5 * len(ranked))))
    tournament_size = min(3, n_top)  # small tournaments = faster adaptation
    while len(next_population) < len(population):
        candidates = np.random.choice(n_top, tournament_size, replace=False)
        winner = ranked[max(candidates, key=lambda i: ranked_fitness[i])]
        # Deep-copy so mutation never touches the parent, which may be an elite
        # already carried over or the parent of other children.
        child_genome = copy.deepcopy(winner.genome)
        child_genome.mutate(config)
        # Build the Network after mutating so its node lists include new nodes.
        next_population.append(Network(child_genome))

    return next_population


def record_gameplay(network: Network, env: SlimeVolley, filename: str = 'gameplay.gif', max_steps: int = 1000):
    """Play one episode with ``network`` and save the rendered frames as a GIF."""
    frames = []

    # The environment is batched: run a batch of exactly one episode.
    state = env.reset(random.split(_new_key(), 1))
    done = False
    steps = 0

    while not done and steps < max_steps:
        frames.append(env.render(state))  # frame is already a PIL Image

        obs = state.obs / 10.0  # already (1, obs_dim); no extra batch axis
        binary_action = _to_binary_actions(network.forward(obs))
        state, _, dones = env.step(state, binary_action)
        done = bool(np.asarray(dones).any())
        steps += 1

    # Save as GIF
    if frames:
        frames[0].save(
            filename,
            save_all=True,
            append_images=frames[1:],
            duration=50,  # 20 fps
            loop=0
        )
        print(f"Gameplay recorded and saved to {filename}")
    else:
        print("No frames were recorded")


def main():
    """Main training loop with hardware acceleration when available."""
    _report_devices()

    # Initialize environment
    env = SlimeVolley(max_steps=1000)

    # Configuration for evolution
    config = {
        'population_size': 64,
        'batch_size': 8,  # Smaller batch size for better compatibility
        'weight_mutation_rate': 0.95,
        'weight_mutation_power': 4.0,
        'add_node_rate': 0.0,
        'add_connection_rate': 0.0,
    }

    print("\nTraining Configuration:")
    print(f"Population Size: {config['population_size']}")
    print(f"Batch Size: {config['batch_size']}")
    print(f"Mutation Rate: {config['weight_mutation_rate']}")
    print("-" * 40)

    # Create initial population
    population = [
        Network(Genome(n_inputs=12, n_outputs=3))
        for _ in range(config['population_size'])
    ]

    best_fitness = float('-inf')
    best_network = None

    # Evolution loop
    for generation in range(1000):
        start_time = time.time()
        print(f"\nGeneration {generation}")

        # Evaluate population in batches
        fitness_scores = evaluate_parallel(
            population, env, batch_size=config['batch_size']
        )

        # Track best network
        max_fitness = max(fitness_scores)
        if max_fitness > best_fitness:
            previous_best = best_fitness
            best_idx = fitness_scores.index(max_fitness)
            best_fitness = max_fitness
            best_network = population[best_idx]
            print(f"New best fitness: {best_fitness:.2f}")

            # Record gameplay for significant improvements. Compare against the
            # previous best: checking after the update could never be true.
            if previous_best > float('-inf') and best_fitness > previous_best + 2.0:
                record_gameplay(best_network, env, f"best_gen_{generation}.gif")

        # Early stopping
        if best_fitness > 8.0:
            print(f"Target fitness reached: {best_fitness:.2f}")
            break

        # Create next generation
        population = create_next_generation(population, fitness_scores, config)

        # Print stats every 5 generations
        if generation % 5 == 0:
            gen_time = time.time() - start_time
            print(f"\nGeneration {generation} Stats:")
            print(f"Best Fitness: {max_fitness:.2f}")
            print(f"Average Fitness: {np.mean(fitness_scores):.2f}")
            print(f"Generation Time: {gen_time:.2f}s")

    print("\nTraining complete!")
    print(f"Best fitness achieved: {best_fitness:.2f}")

    # Save final network
    if best_network is not None:
        record_gameplay(best_network, env, "final_gameplay.gif")


if __name__ == '__main__':
    main()