# File size: 16,828 Bytes
# 0e0538d (presumably the revision hash of this snapshot; confirm against the source repository)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
import jax
import jax.numpy as jnp
from jax import random
from evojax.task.slimevolley import SlimeVolley
from typing import List, Tuple, Dict
import numpy as np
import time

class NodeGene:
    """A single node (neuron) gene.

    Attributes:
        id: Integer identifier of the node.
        type: Role of the node: 'input', 'hidden', or 'output'.
        activation: Name of the activation function stored on the gene.
        bias: Additive bias, always a Python ``float``.
    """

    def __init__(self, id: int, node_type: str, activation: str = 'tanh',
                 bias: float = None):
        """Create a node gene.

        Args:
            id: Integer identifier of the node.
            node_type: 'input', 'hidden', or 'output'.
            activation: Activation name to record on the gene.
            bias: Explicit bias value. When ``None`` a small random bias
                (normal, standard deviation 0.1) is drawn.
        """
        self.id = id
        self.type = node_type  # 'input', 'hidden', or 'output'
        self.activation = activation

        if bias is None:
            # Draw from NumPy's stateful generator instead of a key derived
            # from (id, millisecond timestamp): two genes with the same id
            # created within the same millisecond would otherwise receive
            # exactly the same "random" bias.
            bias = np.random.normal(0.0, 0.1)
        self.bias = float(bias)

class ConnectionGene:
    """A weighted, directed connection between two node genes.

    Attributes:
        source: Id of the node the connection reads from.
        target: Id of the node the connection feeds into.
        weight: Connection weight, always a Python ``float``.
        enabled: Whether the connection is active.
        innovation: ``hash((source, target))``; identical for any two genes
            that link the same pair of nodes.
    """

    def __init__(self, source: int, target: int, weight: float = None, enabled: bool = True):
        """Create a connection gene.

        Args:
            source: Id of the source node.
            target: Id of the target node.
            weight: Explicit weight. When ``None`` a small random weight
                (normal, standard deviation 0.1) is drawn.
            enabled: Initial enabled flag.
        """
        self.source = source
        self.target = target

        if weight is None:
            # Only touch the random generator when a weight is actually
            # needed.  NumPy's stateful generator avoids the duplicate values
            # a millisecond-timestamp seed produces for genes created in
            # quick succession.
            weight = np.random.normal(0.0, 0.1)
        self.weight = float(weight)
        self.enabled = enabled
        self.innovation = hash((source, target))

class Genome:
    """Genotype of a network: node genes plus connection genes.

    Node ids are laid out as inputs ``0 .. n_inputs - 1``, followed by the
    three output nodes, followed by hidden nodes.

    Attributes:
        node_genes: Mapping from node id to its ``NodeGene``.
        connection_genes: List of ``ConnectionGene`` objects.
    """

    def __init__(self, n_inputs: int, n_outputs: int):
        """Build a randomly wired genome.

        Args:
            n_inputs: Number of input nodes.
            n_outputs: Accepted for interface compatibility but ignored; the
                genome always has exactly 3 outputs (left, right, jump).
        """
        # Create input nodes (0 to n_inputs-1)
        self.node_genes = {i: NodeGene(i, 'input') for i in range(n_inputs)}

        # Create exactly 3 output nodes for left, right, jump
        n_outputs = 3  # Force exactly 3 outputs
        input_ids = list(range(n_inputs))
        output_ids = [n_inputs + j for j in range(n_outputs)]
        for node_id in output_ids:
            self.node_genes[node_id] = NodeGene(node_id, 'output')

        self.connection_genes: List[ConnectionGene] = []

        key = self._fresh_key()
        key, direct_key, count_key = random.split(key, 3)

        # Direct input -> output connections, each present with 70% chance.
        self._add_random_links(direct_key, input_ids, output_ids, 0.7)

        # Between 1 and 3 hidden nodes (the upper bound of randint is exclusive).
        n_hidden = int(random.randint(count_key, (), 1, 4))
        hidden_start = n_inputs + n_outputs

        for i in range(n_hidden):
            node_id = hidden_start + i
            self.node_genes[node_id] = NodeGene(node_id, 'hidden')

            key, in_key, out_key = random.split(key, 3)
            # Each input feeds this hidden node with 50% chance ...
            self._add_random_links(in_key, input_ids, [node_id], 0.5)
            # ... and the hidden node feeds each output with 50% chance.
            self._add_random_links(out_key, [node_id], output_ids, 0.5)

    @staticmethod
    def _fresh_key():
        """Return a JAX PRNG key seeded from NumPy's global generator.

        A fixed seed would replay the same random decisions on every call,
        and a millisecond-timestamp seed collides for calls made in quick
        succession; NumPy's stateful generator avoids both problems.
        """
        return random.PRNGKey(int(np.random.randint(0, 2**31 - 1)))

    def _add_random_links(self, key, sources, targets, probability, scale=0.5):
        """Append a connection for each (source, target) pair with the given probability.

        The keep/skip mask and the weights are drawn in two batched calls
        rather than one pair of JAX calls per candidate connection.
        Connections are appended in source-major order.
        """
        if not sources or not targets:
            return
        mask_key, weight_key = random.split(key)
        shape = (len(sources), len(targets))
        keep = np.asarray(random.uniform(mask_key, shape) < probability)
        weights = np.asarray(random.normal(weight_key, shape)) * scale
        for row, source in enumerate(sources):
            for col, target in enumerate(targets):
                if keep[row, col]:
                    self.connection_genes.append(
                        ConnectionGene(source, target, weight=float(weights[row, col]))
                    )

    def mutate(self, config: Dict):
        """Mutate this genome in place.

        Args:
            config: Mapping that provides 'weight_mutation_rate',
                'weight_mutation_power', 'add_node_rate' and
                'add_connection_rate'.
        """
        # A fresh key per call: with a constant seed every genome would
        # receive exactly the same sequence of mutation decisions.
        key = self._fresh_key()

        # Mutate connection weights
        n_connections = len(self.connection_genes)
        if n_connections:
            key, pick_key, reset_key, noise_key = random.split(key, 4)
            picked = np.asarray(
                random.uniform(pick_key, (n_connections,)) < config['weight_mutation_rate']
            )
            reset = np.asarray(random.uniform(reset_key, (n_connections,)) < 0.1)
            noise = np.asarray(random.normal(noise_key, (n_connections,)))
            power = float(config['weight_mutation_power'])
            for conn, hit, replace, z in zip(self.connection_genes, picked, reset, noise):
                if not hit:
                    continue
                if replace:
                    # Sometimes reset the weight completely
                    conn.weight = float(z) * 0.5
                else:
                    # Otherwise perturb the existing weight
                    conn.weight += float(z) * power

        # Mutate node biases (10% chance per node)
        nodes = list(self.node_genes.values())
        if nodes:
            key, pick_key, noise_key = random.split(key, 3)
            picked = np.asarray(random.uniform(pick_key, (len(nodes),)) < 0.1)
            noise = np.asarray(random.normal(noise_key, (len(nodes),)))
            for node, hit, z in zip(nodes, picked, noise):
                if hit:
                    node.bias += float(z) * 0.1

        key, node_key, weight_key, conn_key = random.split(key, 4)

        # Add new node by splitting a randomly chosen connection
        if float(random.uniform(node_key)) < config['add_node_rate'] and self.connection_genes:
            conn = self.connection_genes[np.random.randint(len(self.connection_genes))]
            new_id = max(self.node_genes) + 1

            # Create new node with random bias
            self.node_genes[new_id] = NodeGene(new_id, 'hidden')

            # Both replacement connections get fresh random weights
            weight_in, weight_out = (
                float(w) * 0.5 for w in np.asarray(random.normal(weight_key, (2,)))
            )
            self.connection_genes.append(
                ConnectionGene(conn.source, new_id, weight=weight_in)
            )
            self.connection_genes.append(
                ConnectionGene(new_id, conn.target, weight=weight_out)
            )

            # Disable old connection
            conn.enabled = False

        # Add new connection
        if float(random.uniform(conn_key)) < config['add_connection_rate']:
            node_ids = list(self.node_genes)
            if len(node_ids) >= 2:
                # Output ids are smaller than hidden ids, so comparing raw ids
                # would forbid hidden -> output links and allow
                # output -> hidden ones.  Rank nodes by
                # (inputs, hidden by id, outputs by id) instead and only
                # connect an earlier node to a later one.
                stage = {'input': 0, 'hidden': 1, 'output': 2}
                rank = {
                    node_id: (stage.get(node.type, 1), node_id)
                    for node_id, node in self.node_genes.items()
                }
                existing = {(c.source, c.target) for c in self.connection_genes}

                for _ in range(10):  # Try 10 times to find valid connection
                    source = int(np.random.choice(node_ids))
                    target = int(np.random.choice(node_ids))

                    if self.node_genes[target].type == 'input':
                        continue  # inputs are set from observations, never computed
                    if rank[source] >= rank[target]:
                        continue  # would run backwards (or be a self-loop)
                    if (source, target) in existing:
                        continue  # connection already exists

                    key, new_weight_key = random.split(key)
                    weight = float(random.normal(new_weight_key)) * 0.5
                    self.connection_genes.append(
                        ConnectionGene(source, target, weight=weight)
                    )
                    break

class Network:
    """Phenotype that evaluates a genome as a tanh network.

    Attributes:
        genome: The genome this network was built from.
        input_nodes / hidden_nodes / output_nodes: Node genes of each type,
            sorted by id (captured at construction time).
    """

    def __init__(self, genome: 'Genome'):
        """Group the genome's nodes by type.

        Raises:
            AssertionError: If the genome does not have exactly 3 output nodes.
        """
        self.genome = genome
        # Sort nodes by ID to ensure consistent ordering
        ordered = sorted(genome.node_genes.values(), key=lambda node: node.id)
        self.input_nodes = [node for node in ordered if node.type == 'input']
        self.hidden_nodes = [node for node in ordered if node.type == 'hidden']
        self.output_nodes = [node for node in ordered if node.type == 'output']

        # Verify we have exactly 3 output nodes
        assert len(self.output_nodes) == 3, f"Expected 3 output nodes, got {len(self.output_nodes)}"

    def _evaluation_order(self, incoming):
        """Return hidden and output nodes ordered so sources come before targets.

        Evaluating hidden nodes strictly by id reads a stale value whenever a
        higher-id hidden node feeds a lower-id one (which happens when a
        connection into a hidden node is split by a newer node).  Nodes are
        therefore released once everything they depend on has been computed.
        If the enabled connections contain a cycle, the first remaining node
        in (hidden by id, outputs by id) order is released to break it.

        Args:
            incoming: Mapping from target node id to its enabled connections.
        """
        default_order = self.hidden_nodes + self.output_nodes
        pending = {node.id for node in default_order}
        waiting_on = {
            node.id: {
                conn.source
                for conn in incoming.get(node.id, ())
                if conn.source in pending and conn.source != node.id
            }
            for node in default_order
        }

        ordered = []
        remaining = list(default_order)
        while remaining:
            ready = [node for node in remaining if not (waiting_on[node.id] & pending)]
            if not ready:
                ready = remaining[:1]  # cycle: break it deterministically
            ready_ids = {node.id for node in ready}
            pending -= ready_ids
            ordered.extend(ready)
            remaining = [node for node in remaining if node.id not in ready_ids]
        return ordered

    def forward(self, x: jnp.ndarray) -> jnp.ndarray:
        """Evaluate the network.

        Args:
            x: Inputs of shape ``(batch_size, n_inputs)``; a 1-D array is
                treated as a batch of one.

        Returns:
            Array of shape ``(batch_size, 3)`` holding the tanh activations of
            the output nodes in id order.
        """
        # Ensure input is 2D with shape (batch_size, input_dim)
        if len(x.shape) == 1:
            x = jnp.expand_dims(x, 0)

        batch_size = x.shape[0]

        # Every node starts at its bias; inputs are then overwritten with the
        # corresponding observation column.
        values = {
            node.id: jnp.zeros((batch_size,)) + node.bias
            for node in self.genome.node_genes.values()
        }
        for column, node in enumerate(self.input_nodes):
            values[node.id] = x[:, column]

        # Group enabled connections by target once instead of rescanning the
        # whole connection list for every node.
        incoming = {}
        for conn in self.genome.connection_genes:
            if conn.enabled:
                incoming.setdefault(conn.target, []).append(conn)

        for node in self._evaluation_order(incoming):
            total = jnp.zeros((batch_size,)) + node.bias
            for conn in incoming.get(node.id, ()):
                total = total + values[conn.source] * conn.weight
            values[node.id] = jnp.tanh(total)

        # Stack along new axis to get (batch_size, 3)
        return jnp.stack([values[node.id] for node in self.output_nodes], axis=-1)

def _to_binary_action(raw_action: jnp.ndarray, thresholds: jnp.ndarray) -> jnp.ndarray:
    """Threshold raw (left, right, jump) outputs into a 0/1 float32 action.

    When both left and right exceed their thresholds, only the one with the
    larger raw output is kept (right wins ties).
    """
    pressed = raw_action > thresholds
    left, right, jump = pressed[:, 0], pressed[:, 1], pressed[:, 2]

    both_active = jnp.logical_and(left, right)
    prefer_left = raw_action[:, 0] > raw_action[:, 1]

    left = jnp.where(both_active, prefer_left, left)
    right = jnp.where(both_active, jnp.logical_not(prefer_left), right)
    return jnp.stack([left, right, jump], axis=-1).astype(jnp.float32)


def _shaping_bonus(binary_action: jnp.ndarray, obs) -> float:
    """Return the unscaled per-step shaping bonus.

    NOTE(review): observation entries 1, 4 and 6 are treated as ball height,
    ball x position and ball x velocity, following the original comments;
    confirm against the SlimeVolley observation layout.
    """
    # Flatten so the indices address observation entries whether or not the
    # environment returns a leading batch axis (assumes one environment per
    # call); a single host transfer replaces three separate ones.
    features = np.asarray(jnp.reshape(obs, (-1,)))

    bonus = 0.0
    if bool(jnp.any(binary_action[:, :2] > 0)):
        bonus += 0.1  # moved left or right: encourage exploration
    if features[1] > 0.5:
        bonus += 0.1  # keeping the ball in play
    if features[4] > 0:
        bonus += 0.2  # ball on the opponent's side
    if features[6] > 0:
        bonus += 0.1  # ball travelling towards the opponent
    return bonus


def evaluate_network(network: Network, env: SlimeVolley, n_episodes: int = 10) -> float:
    """Return the mean shaped episode reward of ``network`` on ``env``.

    Each step contributes twice the environment reward plus half of the
    shaping bonus; an episode that ends on a positive reward earns an extra
    10.0.  Episodes are cut off after 1000 steps.

    Args:
        network: Object whose ``forward(obs)`` maps a ``(1, obs_dim)`` array
            to ``(1, 3)`` raw action values.
        env: Environment providing ``reset(keys)`` and ``step(state, action)``;
            ``step`` must return ``(next_state, reward, done)``.
        n_episodes: Number of episodes to average over; must be positive.

    Raises:
        ValueError: If ``n_episodes`` is not positive.
    """
    if n_episodes <= 0:
        raise ValueError(f"n_episodes must be positive, got {n_episodes}")

    max_steps = 1000
    thresholds = jnp.array([0.3, 0.3, 0.4])  # left, right, jump

    # Generate a unique key for this evaluation
    timestamp = int(time.time() * 1000)
    master_key = random.PRNGKey(hash((id(network), timestamp)) % (2**32))

    total_reward = 0.0
    for _ in range(n_episodes):
        # Reset environment with a batch of one key
        master_key, reset_key = random.split(master_key)
        state = env.reset(reset_key[None, :])
        done = False
        episode_reward = 0.0
        steps = 0

        while not done and steps < max_steps:
            # Scale inputs and force shape (1, obs_dim).  Reshaping (rather
            # than prepending an axis) also works when the environment
            # already returns a batched observation.
            obs = jnp.reshape(state.obs, (1, -1)) / 10.0

            binary_action = _to_binary_action(network.forward(obs), thresholds)

            next_state, reward, done = env.step(state, binary_action)

            # Take the first element whether the values are scalars or batched
            reward = float(jnp.reshape(jnp.asarray(reward), (-1,))[0])
            done = bool(jnp.reshape(jnp.asarray(done), (-1,))[0])

            # Game outcome counts double; shaping bonuses are scaled down
            bonus_reward = _shaping_bonus(binary_action, next_state.obs)
            episode_reward += reward * 2.0 + bonus_reward * 0.5

            state = next_state
            steps += 1

            # Early termination bonus
            if done and reward > 0:  # Won the point
                episode_reward += 10.0

        total_reward += episode_reward

    return total_reward / n_episodes

def main():
    """Evolve a population of networks on SlimeVolley for 500 generations.

    Every generation each network is scored with ``evaluate_network``; the
    top 5 are carried over unchanged and the rest of the next population is
    filled with mutated copies of tournament winners drawn from the top 20.
    Progress is printed to stdout.
    """
    # Local import keeps this function self-contained; deepcopy is needed so
    # that offspring never share gene objects with their parent.
    import copy

    n_inputs, n_outputs = 12, 3
    n_generations = 500  # More generations
    n_elite = 5  # Fewer elites
    tournament_size = 5
    parent_pool_size = 20

    # Initialize environment
    env = SlimeVolley()

    # NEAT configuration
    config = {
        'population_size': 50,  # Smaller population for faster iteration
        'weight_mutation_rate': 0.8,
        'weight_mutation_power': 0.3,  # Increased for more exploration
        'add_node_rate': 0.3,
        'add_connection_rate': 0.5,
    }

    # Create initial population
    population = [
        Network(Genome(n_inputs=n_inputs, n_outputs=n_outputs))
        for _ in range(config['population_size'])
    ]

    best_fitness = float('-inf')
    generations_without_improvement = 0

    # Evolution loop
    for generation in range(n_generations):
        print(f"\nGeneration {generation}")
        print("-" * 20)

        # Evaluate population
        fitnesses = []
        for i, net in enumerate(population):
            fitness = evaluate_network(net, env)
            fitnesses.append(fitness)
            print(f"Network {i}: Fitness = {fitness:.2f}")

            if fitness > best_fitness:
                best_fitness = fitness
                generations_without_improvement = 0
                print(f"New best fitness: {best_fitness:.2f}")

        # Check for improvement
        generations_without_improvement += 1
        if generations_without_improvement > 20:
            print("No improvement for 20 generations, increasing mutation rates")
            config['weight_mutation_rate'] = min(1.0, config['weight_mutation_rate'] * 1.2)
            config['weight_mutation_power'] = min(0.5, config['weight_mutation_power'] * 1.2)
            generations_without_improvement = 0

        # Print progress
        avg_fitness = sum(fitnesses) / len(fitnesses)
        print(f"\nBest fitness: {best_fitness:.2f}")
        print(f"Average fitness: {avg_fitness:.2f}")

        # Selection and reproduction
        sorted_indices = np.argsort(fitnesses)[::-1]  # Best to worst

        # Keep best networks
        new_population = [population[i] for i in sorted_indices[:n_elite]]
        print(f"Keeping top {n_elite} networks")

        # Create offspring from best networks
        parent_pool = sorted_indices[:parent_pool_size]
        # Never ask for more distinct contestants than the pool contains
        contestants = min(tournament_size, len(parent_pool))
        while len(new_population) < config['population_size']:
            # Tournament selection
            tournament = np.random.choice(parent_pool, contestants, replace=False)
            parent_idx = tournament[np.argmax([fitnesses[i] for i in tournament])]
            parent = population[parent_idx]

            # Deep-copy the parent genome.  Copying only the containers would
            # leave parent and child sharing the same gene objects, so
            # mutating the child would silently alter the parent (and any
            # elite built from it).
            child_genome = copy.deepcopy(parent.genome)

            # Mutate child
            child_genome.mutate(config)

            # Add to new population
            new_population.append(Network(child_genome))

        population = new_population
        print(f"Created {len(population)} networks for next generation")

# Run the evolution loop only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()