# File size: 16,828 Bytes
# 0e0538d
import jax
import jax.numpy as jnp
from jax import random
from evojax.task.slimevolley import SlimeVolley
from typing import List, Tuple, Dict
import numpy as np
import time
class NodeGene:
    """A single node of a genome.

    Attributes set by the constructor:
        id: identifier of the node, as passed in.
        type: the ``node_type`` argument ('input', 'hidden', or 'output').
        activation: name of the activation function (default ``'tanh'``).
        bias: Python float drawn from a standard normal and scaled by 0.1.
    """

    # Number of nodes constructed so far in this process. It is mixed into
    # the PRNG seed so two nodes never share a seed, whatever the clock
    # resolution of the platform is.
    _instances_created = 0

    def __init__(self, id: int, node_type: str, activation: str = 'tanh'):
        """Create a node with a small random bias.

        Args:
            id: identifier stored on the node.
            node_type: 'input', 'hidden', or 'output'.
            activation: activation name stored on the node.
        """
        self.id = id
        self.type = node_type  # 'input', 'hidden', or 'output'
        self.activation = activation

        # A millisecond timestamp alone gives identical biases to nodes with
        # the same id that are created within the same millisecond, so the
        # seed also includes a nanosecond clock reading and a counter.
        NodeGene._instances_created += 1
        seed = hash((id, time.time_ns(), NodeGene._instances_created)) % (2**32)
        key = random.PRNGKey(seed)
        self.bias = float(random.normal(key, shape=()) * 0.1)  # Small random bias
class ConnectionGene:
    """A weighted, directed connection between two nodes of a genome.

    Attributes set by the constructor:
        source: id of the node the connection reads from.
        target: id of the node the connection feeds into.
        weight: connection weight (Python float when randomly initialised).
        enabled: whether the connection takes part in evaluation.
        innovation: ``hash((source, target))``; equal for genes that join the
            same pair of nodes.
    """

    # Number of randomly initialised connections created so far in this
    # process. It is mixed into the PRNG seed so two genes never share a seed.
    _instances_created = 0

    def __init__(self, source: int, target: int, weight: float = None, enabled: bool = True):
        """Create a connection gene.

        Args:
            source: id of the source node.
            target: id of the target node.
            weight: explicit weight; when ``None`` a small random weight
                (standard normal scaled by 0.1) is drawn.
            enabled: initial enabled flag.
        """
        self.source = source
        self.target = target
        if weight is None:
            # Build a PRNG key only when a random weight is actually needed;
            # callers that pass an explicit weight skip the JAX work entirely.
            ConnectionGene._instances_created += 1
            seed = hash(
                (source, target, time.time_ns(), ConnectionGene._instances_created)
            ) % (2**32)
            weight = float(random.normal(random.PRNGKey(seed), shape=()) * 0.1)
        self.weight = weight
        self.enabled = enabled
        self.innovation = hash((source, target))
class Genome:
    """Genotype: node genes keyed by id plus a list of connection genes.

    Node ids are assigned as inputs ``0 .. n_inputs - 1``, then the three
    output nodes, then hidden nodes.
    """

    # Number of PRNG keys derived by this class so far. It is mixed into
    # every seed so that no two derived keys are built from the same inputs.
    _keys_issued = 0

    def __init__(self, n_inputs: int, n_outputs: int):
        """Build a randomly wired genome.

        Args:
            n_inputs: number of input nodes.
            n_outputs: accepted for interface compatibility but overridden
                to 3 (left, right, jump).
        """
        # Create input nodes (0 to n_inputs-1)
        self.node_genes = {i: NodeGene(i, 'input') for i in range(n_inputs)}
        # Create exactly 3 output nodes for left, right, jump
        n_outputs = 3  # Force exactly 3 outputs
        for i in range(n_outputs):
            self.node_genes[n_inputs + i] = NodeGene(n_inputs + i, 'output')
        self.connection_genes: List[ConnectionGene] = []

        master_key = self._fresh_key(n_inputs, n_outputs)

        # Direct input -> output connections, each kept with probability 0.7.
        for i in range(n_inputs):
            for j in range(n_outputs):
                master_key = self._maybe_connect(master_key, i, n_inputs + j, 0.7)

        # Between 1 and 3 hidden nodes (randint's upper bound is exclusive).
        master_key, key = random.split(master_key)
        n_hidden = int(random.randint(key, (), 1, 4))
        hidden_start = n_inputs + n_outputs
        for i in range(n_hidden):
            node_id = hidden_start + i
            self.node_genes[node_id] = NodeGene(node_id, 'hidden')
            # Each input feeds this hidden node with probability 0.5.
            for j in range(n_inputs):
                master_key = self._maybe_connect(master_key, j, node_id, 0.5)
            # This hidden node feeds each output with probability 0.5.
            for j in range(n_outputs):
                master_key = self._maybe_connect(master_key, node_id, n_inputs + j, 0.5)

    @staticmethod
    def _fresh_key(*parts):
        """Return a PRNG key seeded from ``parts``, the clock and a counter."""
        Genome._keys_issued += 1
        seed = hash((*parts, time.time_ns(), Genome._keys_issued)) % (2**32)
        return random.PRNGKey(seed)

    def _maybe_connect(self, key, source, target, probability):
        """Append a source->target gene with the given probability.

        The weight is a standard normal scaled by 0.5. Returns the advanced
        key so the caller can keep threading it.
        """
        key, gate_key = random.split(key)
        if random.uniform(gate_key, shape=()) < probability:
            key, weight_key = random.split(key)
            weight = float(random.normal(weight_key, shape=()) * 0.5)
            self.connection_genes.append(ConnectionGene(source, target, weight=weight))
        return key

    def mutate(self, config: Dict):
        """Mutate this genome in place.

        Args:
            config: mapping with the keys ``'weight_mutation_rate'``,
                ``'weight_mutation_power'``, ``'add_node_rate'`` and
                ``'add_connection_rate'``.
        """
        # A constant seed would replay the same random decisions on every
        # call, so a fresh key is derived per call.
        key = self._fresh_key(id(self))

        # Mutate connection weights
        for conn in self.connection_genes:
            key, subkey = random.split(key)
            if random.uniform(subkey, shape=()) < config['weight_mutation_rate']:
                key, subkey = random.split(key)
                if random.uniform(subkey, shape=()) < 0.1:
                    # Sometimes reset the weight completely
                    key, subkey = random.split(key)
                    conn.weight = float(random.normal(subkey, shape=()) * 0.5)
                else:
                    # Otherwise nudge the existing weight
                    key, subkey = random.split(key)
                    conn.weight += float(random.normal(subkey) * config['weight_mutation_power'])

        # Mutate node biases (10% chance per node)
        for node in self.node_genes.values():
            key, subkey = random.split(key)
            if random.uniform(subkey, shape=()) < 0.1:
                key, subkey = random.split(key)
                node.bias += float(random.normal(subkey) * 0.1)

        # Add new node by splitting a randomly chosen connection
        key, subkey = random.split(key)
        if random.uniform(subkey, shape=()) < config['add_node_rate'] and self.connection_genes:
            key, subkey = random.split(key)
            split_idx = int(random.randint(subkey, (), 0, len(self.connection_genes)))
            conn = self.connection_genes[split_idx]
            new_id = max(self.node_genes) + 1
            self.node_genes[new_id] = NodeGene(new_id, 'hidden')
            key, subkey = random.split(key)
            weight1 = float(random.normal(subkey, shape=()) * 0.5)
            key, subkey = random.split(key)
            weight2 = float(random.normal(subkey, shape=()) * 0.5)
            self.connection_genes.append(
                ConnectionGene(conn.source, new_id, weight=weight1)
            )
            self.connection_genes.append(
                ConnectionGene(new_id, conn.target, weight=weight2)
            )
            # The split connection is replaced by the two new ones
            conn.enabled = False

        # Add new connection
        key, subkey = random.split(key)
        if random.uniform(subkey, shape=()) < config['add_connection_rate']:
            nodes = list(self.node_genes)
            existing = {(c.source, c.target) for c in self.connection_genes}
            # Try up to 10 times to find a valid, not yet existing connection
            for _ in range(10 if len(nodes) > 1 else 0):
                key, subkey = random.split(key)
                src_idx, dst_idx = (
                    int(v) for v in random.randint(subkey, (2,), 0, len(nodes))
                )
                source, target = nodes[src_idx], nodes[dst_idx]
                # Ensure forward propagation (source id < target id)
                if source >= target:
                    continue
                # A connection into an input node can never be read: input
                # values are overwritten with the observation.
                if self.node_genes[target].type == 'input':
                    continue
                if (source, target) in existing:
                    continue
                key, subkey = random.split(key)
                weight = float(random.normal(subkey, shape=()) * 0.5)
                self.connection_genes.append(
                    ConnectionGene(source, target, weight=weight)
                )
                break
class Network:
    """Evaluates a genome as a single forward pass over its nodes."""

    def __init__(self, genome: 'Genome'):
        """Index the genome's nodes by role.

        Args:
            genome: object exposing ``node_genes`` (dict of nodes with
                ``id``, ``type`` and ``bias``) and ``connection_genes``.

        Raises:
            AssertionError: if the genome does not have exactly 3 output nodes.
        """
        self.genome = genome
        # Sort nodes by ID once so every role list has a consistent ordering
        ordered = sorted(genome.node_genes.values(), key=lambda n: n.id)
        self.input_nodes = [n for n in ordered if n.type == 'input']
        self.hidden_nodes = [n for n in ordered if n.type == 'hidden']
        self.output_nodes = [n for n in ordered if n.type == 'output']
        # Verify we have exactly 3 output nodes
        assert len(self.output_nodes) == 3, f"Expected 3 output nodes, got {len(self.output_nodes)}"

    def forward(self, x: jnp.ndarray) -> jnp.ndarray:
        """Compute the three output activations for a batch of inputs.

        Args:
            x: array of shape ``(input_dim,)`` or ``(batch_size, input_dim)``;
                column ``i`` feeds the i-th input node (by ascending id).

        Returns:
            Array of shape ``(batch_size, 3)``, one column per output node.
        """
        # Ensure input is 2D with shape (batch_size, input_dim)
        if len(x.shape) == 1:
            x = jnp.expand_dims(x, 0)
        batch_size = x.shape[0]

        # Every node starts at its bias; inputs are then overwritten.
        values = {
            node.id: jnp.zeros((batch_size,)) + node.bias
            for node in self.genome.node_genes.values()
        }
        for i, node in enumerate(self.input_nodes):
            values[node.id] = x[:, i]

        # Group enabled connections by target in one pass instead of scanning
        # the whole connection list for every node. List order is preserved,
        # so each node's sum is accumulated in the same order as before.
        incoming = {}
        for conn in self.genome.connection_genes:
            if conn.enabled:
                incoming.setdefault(conn.target, []).append(conn)

        # Hidden nodes are processed in ascending id order, then outputs.
        # A source that has not been processed yet contributes its bias.
        # tanh is applied to every node; the node's ``activation`` attribute
        # is not consulted here.
        for node in self.hidden_nodes + self.output_nodes:
            total = jnp.zeros((batch_size,)) + node.bias
            for conn in incoming.get(node.id, ()):
                total = total + values[conn.source] * conn.weight
            values[node.id] = jnp.tanh(total)

        # Stack along a new last axis to get (batch_size, 3)
        return jnp.stack([values[node.id] for node in self.output_nodes], axis=-1)
def _first_scalar(value):
    """Return the first element of a Python/NumPy/JAX scalar or array."""
    return np.asarray(value).reshape(-1)[0]


def _to_binary_action(raw_action, thresholds):
    """Threshold raw outputs into 0/1 actions, never pressing left and right together."""
    binary = (raw_action > thresholds).astype(jnp.float32)
    both_active = jnp.logical_and(binary[:, 0] > 0, binary[:, 1] > 0)
    # When both directions fire, keep only the one with the larger raw output.
    prefer_left = raw_action[:, 0] > raw_action[:, 1]
    left = jnp.where(both_active, prefer_left.astype(jnp.float32), binary[:, 0])
    right = jnp.where(both_active, (~prefer_left).astype(jnp.float32), binary[:, 1])
    return binary.at[:, 0].set(left).at[:, 1].set(right)


def _shaping_bonus(binary_action, obs):
    """Return the per-step shaping bonus (before the 0.5 scaling).

    NOTE(review): the indices follow the original comments, which label
    index 1 as ball height, 4 as ball x position and 6 as ball x velocity.
    The observation layout is defined by the environment -- confirm,
    especially index 1.
    """
    # Flatten once so the same indices work for (obs_dim,) and (1, obs_dim).
    flat = np.asarray(obs).reshape(-1)
    movement_reward = 0.1 if bool(jnp.any(binary_action[:, :2] > 0)) else 0.0
    height_reward = 0.1 if float(flat[1]) > 0.5 else 0.0
    position_reward = 0.2 if float(flat[4]) > 0 else 0.0
    velocity_reward = 0.1 if float(flat[6]) > 0 else 0.0
    return movement_reward + height_reward + position_reward + velocity_reward


def evaluate_network(network: 'Network', env: 'SlimeVolley', n_episodes: int = 10) -> float:
    """Return the mean shaped episode reward of ``network`` on ``env``.

    Each episode runs for at most 1000 steps. Per step the environment reward
    is doubled and half of the shaping bonus is added. An episode that ends
    with a positive final reward earns an extra 10.

    Args:
        network: object with ``forward(obs)`` returning raw actions of shape
            ``(1, 3)``.
        env: object with ``reset(keys)`` and ``step(state, action)``; a single
            environment instance is assumed (one reset key is passed).
        n_episodes: number of episodes to average over (must be >= 1).

    Returns:
        Average shaped reward per episode.

    Raises:
        ValueError: if ``n_episodes`` is smaller than 1.
    """
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be at least 1, got {n_episodes}")

    total_reward = 0.0
    # Generate a unique key for this evaluation
    timestamp = int(time.time() * 1000)
    master_key = random.PRNGKey(hash((id(network), timestamp)) % (2**32))
    thresholds = jnp.array([0.3, 0.3, 0.4])  # left, right, jump

    for _ in range(n_episodes):
        master_key, reset_key = random.split(master_key)
        state = env.reset(reset_key[None, :])  # Add batch dimension to the key
        done = False
        reward = 0.0
        episode_reward = 0.0
        steps = 0
        while not done and steps < 1000:  # Step limit
            # Reshape instead of prepending an axis: the observation may
            # already carry a batch dimension, and the network expects
            # exactly (1, obs_dim). Inputs are scaled down by 10.
            obs = jnp.reshape(state.obs, (1, -1)) / 10.0
            raw_action = network.forward(obs)
            binary_action = _to_binary_action(raw_action, thresholds)

            next_state, reward, done = env.step(state, binary_action)
            # Reward and done may come back batched; take the first element.
            reward = float(_first_scalar(reward))
            done = bool(_first_scalar(done))

            bonus_reward = _shaping_bonus(binary_action, next_state.obs)
            # Game outcome counts double; shaping bonuses are scaled down.
            episode_reward += reward * 2.0 + bonus_reward * 0.5
            state = next_state
            steps += 1

        # Early termination bonus
        if done and reward > 0:  # Won the point
            episode_reward += 10.0
        total_reward += episode_reward
    return total_reward / n_episodes
def main():
    """Run the evolutionary training loop on SlimeVolley.

    Builds a population of random networks, then for 500 generations
    evaluates every network, keeps the top 5 unchanged and refills the
    population with mutated copies of tournament winners drawn from the
    top 20. Progress is printed to stdout.
    """
    # Local import: only this function needs it.
    import copy

    # Initialize environment
    env = SlimeVolley()

    # NEAT configuration
    config = {
        'population_size': 50,  # Smaller population for faster iteration
        'weight_mutation_rate': 0.8,
        'weight_mutation_power': 0.3,  # Increased for more exploration
        'add_node_rate': 0.3,
        'add_connection_rate': 0.5,
    }

    # Create initial population
    population = [
        Network(Genome(n_inputs=12, n_outputs=3))
        for _ in range(config['population_size'])
    ]
    best_fitness = float('-inf')
    generations_without_improvement = 0

    # Evolution loop
    for generation in range(500):  # More generations
        print(f"\nGeneration {generation}")
        print("-" * 20)

        # Evaluate population
        fitnesses = []
        for i, net in enumerate(population):
            fitness = evaluate_network(net, env)
            fitnesses.append(fitness)
            print(f"Network {i}: Fitness = {fitness:.2f}")
            if fitness > best_fitness:
                best_fitness = fitness
                generations_without_improvement = 0
                print(f"New best fitness: {best_fitness:.2f}")

        # Check for improvement
        generations_without_improvement += 1
        if generations_without_improvement > 20:
            print("No improvement for 20 generations, increasing mutation rates")
            config['weight_mutation_rate'] = min(1.0, config['weight_mutation_rate'] * 1.2)
            config['weight_mutation_power'] = min(0.5, config['weight_mutation_power'] * 1.2)
            generations_without_improvement = 0

        # Print progress
        avg_fitness = sum(fitnesses) / len(fitnesses)
        print(f"\nBest fitness: {best_fitness:.2f}")
        print(f"Average fitness: {avg_fitness:.2f}")

        # Selection and reproduction
        new_population = []
        sorted_indices = np.argsort(fitnesses)[::-1]  # Best to worst

        # Keep best networks
        n_elite = 5  # Fewer elites
        new_population.extend(population[i] for i in sorted_indices[:n_elite])
        print(f"Keeping top {n_elite} networks")

        # Create offspring from best networks
        tournament_pool = sorted_indices[:20]
        # Never ask for more distinct contestants than the pool holds.
        tournament_size = min(5, len(tournament_pool))
        while len(new_population) < config['population_size']:
            # Tournament selection
            tournament = np.random.choice(tournament_pool, tournament_size, replace=False)
            parent_idx = tournament[np.argmax([fitnesses[i] for i in tournament])]
            parent = population[parent_idx]

            # Deep-copy the parent's genome. A shallow copy of the gene
            # containers would share the gene objects, so mutating the child
            # would also change the parent (and any elite) in place.
            child_genome = copy.deepcopy(parent.genome)
            child_genome.mutate(config)

            # Add to new population
            new_population.append(Network(child_genome))

        population = new_population
        print(f"Created {len(population)} networks for next generation")
# Run the training loop only when this file is executed as a script.
if __name__ == '__main__':
    main()