eyad-silx committed on
Commit
0e0538d
·
verified ·
1 Parent(s): 8bc3d16

Upload old_train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. old_train.py +377 -0
old_train.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from jax import random
4
+ from evojax.task.slimevolley import SlimeVolley
5
+ from typing import List, Tuple, Dict
6
+ import numpy as np
7
+ import time
8
+
9
class NodeGene:
    """A single node of the evolved network.

    Attributes:
        id: Integer identifier of the node.
        type: Role of the node: 'input', 'hidden', or 'output'.
        activation: Name of the activation function (defaults to 'tanh').
        bias: Small random bias, drawn once at construction time.
    """

    def __init__(self, id: int, node_type: str, activation: str = 'tanh'):
        self.id, self.type, self.activation = id, node_type, activation

        # Mix the node id with the wall clock (in milliseconds) so that
        # separately created nodes do not all receive the same bias.
        millis = int(time.time() * 1000)
        seed = hash((id, millis)) % (2**32)
        draw = random.normal(random.PRNGKey(seed), shape=())
        self.bias = float(draw * 0.1)  # scaled down to keep the bias small
19
+
20
class ConnectionGene:
    """A weighted, directed link from one node to another.

    Attributes:
        source: Id of the node the signal comes from.
        target: Id of the node the signal feeds into.
        weight: Multiplier applied to the source value.
        enabled: Whether the link is marked as active.
        innovation: Hash of the (source, target) pair, identifying the link.
    """

    def __init__(self, source: int, target: int, weight: float = None, enabled: bool = True):
        """Create the link, drawing a small random weight when none is given.

        Args:
            source: Id of the source node.
            target: Id of the target node.
            weight: Explicit weight; when None a random one is generated.
            enabled: Initial enabled flag.
        """
        self.source = source
        self.target = target

        if weight is None:
            # Only build a PRNG key when a weight actually has to be drawn;
            # callers that pass an explicit weight skip the JAX work entirely.
            # The seed mixes the endpoints with the wall clock (milliseconds).
            timestamp = int(time.time() * 1000)
            key = random.PRNGKey(hash((source, target, timestamp)) % (2**32))
            _, subkey = random.split(key)
            weight = float(random.normal(subkey, shape=()) * 0.1)  # small random weight

        self.weight = weight
        self.enabled = enabled
        self.innovation = hash((source, target))
35
+
36
class Genome:
    """Node and connection genes describing one network.

    Attributes:
        node_genes: Mapping from node id to its NodeGene.
        connection_genes: List of ConnectionGene links between those nodes.

    Node ids are laid out as: inputs first (0 .. n_inputs-1), then the three
    outputs, then hidden nodes.
    """

    def __init__(self, n_inputs: int, n_outputs: int):
        """Build a randomly wired genome.

        Args:
            n_inputs: Number of input nodes.
            n_outputs: Accepted for interface compatibility but overridden:
                the genome always has exactly 3 outputs (left, right, jump).
        """
        # Create input nodes (0 to n_inputs-1)
        self.node_genes = {i: NodeGene(i, 'input') for i in range(n_inputs)}

        # Create exactly 3 output nodes for left, right, jump
        n_outputs = 3  # Force exactly 3 outputs
        for i in range(n_outputs):
            self.node_genes[n_inputs + i] = NodeGene(n_inputs + i, 'output')

        self.connection_genes: List[ConnectionGene] = []

        # Seed the wiring from the layout plus the wall clock (milliseconds).
        timestamp = int(time.time() * 1000)
        master_key = random.PRNGKey(hash((n_inputs, n_outputs, timestamp)) % (2**32))

        # Direct input -> output links, each present with 70% probability.
        for i in range(n_inputs):
            for j in range(n_outputs):
                master_key = self._maybe_connect(master_key, i, n_inputs + j, 0.7)

        # Between 1 and 3 hidden nodes (the upper bound of randint is exclusive).
        master_key, key = random.split(master_key)
        n_hidden = int(random.randint(key, (), 1, 4))
        hidden_start = n_inputs + n_outputs

        for i in range(n_hidden):
            node_id = hidden_start + i
            self.node_genes[node_id] = NodeGene(node_id, 'hidden')

            # Each input feeds this hidden node with 50% probability.
            for j in range(n_inputs):
                master_key = self._maybe_connect(master_key, j, node_id, 0.5)

            # This hidden node feeds each output with 50% probability.
            for j in range(n_outputs):
                master_key = self._maybe_connect(master_key, node_id, n_inputs + j, 0.5)

    def _maybe_connect(self, key, source: int, target: int, probability: float):
        """Add a source->target link with the given probability.

        Returns the advanced PRNG key so the caller can keep threading it.
        """
        key, gate_key = random.split(key)
        if random.uniform(gate_key, shape=()) < probability:
            key, weight_key = random.split(key)
            weight = float(random.normal(weight_key, shape=()) * 0.5)
            self.connection_genes.append(ConnectionGene(source, target, weight=weight))
        return key

    def _creates_cycle(self, source: int, target: int) -> bool:
        """Return True if adding source->target would close a loop.

        That is the case when `source` is already reachable from `target`.
        Disabled links are followed as well, which is the conservative choice.
        """
        successors = {}
        for conn in self.connection_genes:
            successors.setdefault(conn.source, []).append(conn.target)

        stack, seen = [target], set()
        while stack:
            node_id = stack.pop()
            if node_id == source:
                return True
            if node_id in seen:
                continue
            seen.add(node_id)
            stack.extend(successors.get(node_id, ()))
        return False

    def _is_valid_new_connection(self, source: int, target: int) -> bool:
        """Check that source->target is a new, feed-forward link."""
        if source == target:
            return False
        # Validate by node role, not by raw id: output ids are lower than
        # hidden ids, so an id comparison rejects every hidden->output link
        # and lets output->hidden ones through.
        if self.node_genes[source].type == 'output':
            return False
        if self.node_genes[target].type == 'input':
            return False
        if any(c.source == source and c.target == target for c in self.connection_genes):
            return False
        return not self._creates_cycle(source, target)

    def mutate(self, config: Dict):
        """Mutate this genome in place.

        Args:
            config: Must provide 'weight_mutation_rate',
                'weight_mutation_power', 'add_node_rate' and
                'add_connection_rate'.
        """
        # Draw a fresh seed on every call. A fixed seed would make every
        # mutation replay exactly the same decisions and perturbations.
        key = random.PRNGKey(int(np.random.randint(0, 2**31 - 1)))

        # Mutate connection weights
        for conn in self.connection_genes:
            key, subkey = random.split(key)
            if random.uniform(subkey, shape=()) < config['weight_mutation_rate']:
                key, subkey = random.split(key)
                if random.uniform(subkey, shape=()) < 0.1:
                    # Sometimes reset the weight completely
                    key, subkey = random.split(key)
                    conn.weight = float(random.normal(subkey, shape=()) * 0.5)
                else:
                    # Otherwise nudge the existing weight
                    key, subkey = random.split(key)
                    conn.weight += float(random.normal(subkey) * config['weight_mutation_power'])

        # Mutate node biases (10% chance per node)
        for node in self.node_genes.values():
            key, subkey = random.split(key)
            if random.uniform(subkey, shape=()) < 0.1:
                key, subkey = random.split(key)
                node.bias += float(random.normal(subkey) * 0.1)

        # Add new node by splitting a randomly chosen connection
        key, subkey = random.split(key)
        if random.uniform(subkey, shape=()) < config['add_node_rate'] and self.connection_genes:
            conn = self.connection_genes[int(np.random.randint(len(self.connection_genes)))]
            new_id = max(self.node_genes.keys()) + 1
            self.node_genes[new_id] = NodeGene(new_id, 'hidden')

            key, subkey = random.split(key)
            weight1 = float(random.normal(subkey, shape=()) * 0.5)
            key, subkey = random.split(key)
            weight2 = float(random.normal(subkey, shape=()) * 0.5)

            self.connection_genes.append(
                ConnectionGene(conn.source, new_id, weight=weight1)
            )
            self.connection_genes.append(
                ConnectionGene(new_id, conn.target, weight=weight2)
            )

            # The new two-link path replaces the old direct link.
            conn.enabled = False

        # Add new connection
        key, subkey = random.split(key)
        if random.uniform(subkey, shape=()) < config['add_connection_rate']:
            nodes = list(self.node_genes.keys())
            for _ in range(10):  # Try 10 times to find a valid connection
                source = int(np.random.choice(nodes))
                target = int(np.random.choice(nodes))
                if not self._is_valid_new_connection(source, target):
                    continue

                key, subkey = random.split(key)
                weight = float(random.normal(subkey, shape=()) * 0.5)
                self.connection_genes.append(
                    ConnectionGene(source, target, weight=weight)
                )
                break
163
+
164
class Network:
    """Executable view of a genome: evaluates it as a feed-forward network.

    Attributes:
        genome: The genome being evaluated.
        input_nodes / hidden_nodes / output_nodes: The genome's nodes grouped
            by role, each list sorted by node id.
    """

    def __init__(self, genome: 'Genome'):
        """Group the genome's nodes by role.

        Raises:
            AssertionError: If the genome does not have exactly 3 output nodes.
        """
        self.genome = genome

        def nodes_of(kind):
            # Sort by id so the ordering is consistent across networks.
            return sorted(
                (n for n in genome.node_genes.values() if n.type == kind),
                key=lambda n: n.id,
            )

        self.input_nodes = nodes_of('input')
        self.hidden_nodes = nodes_of('hidden')
        self.output_nodes = nodes_of('output')

        # Verify we have exactly 3 output nodes
        assert len(self.output_nodes) == 3, f"Expected 3 output nodes, got {len(self.output_nodes)}"

    def _evaluation_order(self, incoming):
        """Return hidden and output nodes ordered so sources come first.

        A node is emitted once none of its enabled incoming links comes from a
        node that is still waiting. Ascending id order is not sufficient on
        its own: a node inserted by splitting a link into an existing hidden
        node has a higher id than the node it feeds. If the links form a
        cycle, the remaining nodes are appended in hidden-then-output id
        order.

        Args:
            incoming: Mapping from target node id to its enabled connections.
        """
        pending = self.hidden_nodes + self.output_nodes
        pending_ids = {node.id for node in pending}
        ordered = []
        while pending:
            ready = [
                node for node in pending
                if not any(c.source in pending_ids for c in incoming.get(node.id, ()))
            ]
            if not ready:
                ordered.extend(pending)
                break
            ordered.extend(ready)
            pending_ids -= {node.id for node in ready}
            pending = [node for node in pending if node.id in pending_ids]
        return ordered

    def forward(self, x: jnp.ndarray) -> jnp.ndarray:
        """Evaluate the network.

        Args:
            x: Inputs; a 1-D array is treated as a single sample, otherwise
                column i along axis 1 feeds the i-th input node.

        Returns:
            The three output node values stacked along the last axis.
        """
        # Ensure input is 2D with shape (batch_size, input_dim)
        if len(x.shape) == 1:
            x = jnp.expand_dims(x, 0)

        batch_size = x.shape[0]

        # Every node starts at its bias; inputs are then overwritten.
        values = {
            node.id: jnp.zeros((batch_size,)) + node.bias
            for node in self.genome.node_genes.values()
        }
        for i, node in enumerate(self.input_nodes):
            values[node.id] = x[:, i]

        # Index enabled links by target once instead of rescanning the whole
        # connection list for every node. Rebuilt on each call so that later
        # changes to the genome's connections are picked up.
        incoming = {}
        for conn in self.genome.connection_genes:
            if conn.enabled:
                incoming.setdefault(conn.target, []).append(conn)

        for node in self._evaluation_order(incoming):
            total = jnp.zeros((batch_size,)) + node.bias
            for conn in incoming.get(node.id, ()):
                total = total + values[conn.source] * conn.weight
            values[node.id] = jnp.tanh(total)

        # Stack along a new last axis to get (batch_size, 3)
        return jnp.stack([values[node.id] for node in self.output_nodes], axis=-1)
212
+
213
def _to_binary_action(raw_action, thresholds):
    """Threshold raw outputs into 0/1 actions, never pressing left and right together."""
    binary_action = (raw_action > thresholds).astype(jnp.float32)

    # When both directions fire, keep only the one with the larger raw output.
    both_active = jnp.logical_and(binary_action[:, 0] > 0, binary_action[:, 1] > 0)
    prefer_left = raw_action[:, 0] > raw_action[:, 1]

    binary_action = binary_action.at[:, 0].set(
        jnp.where(both_active, prefer_left.astype(jnp.float32), binary_action[:, 0])
    )
    binary_action = binary_action.at[:, 1].set(
        jnp.where(both_active, (~prefer_left).astype(jnp.float32), binary_action[:, 1])
    )
    return binary_action


def evaluate_network(network: 'Network', env: 'SlimeVolley', n_episodes: int = 10) -> float:
    """Play `n_episodes` episodes and return the mean shaped episode reward.

    Each step contributes twice the environment reward plus half of a bonus
    made of small shaping terms; an episode that ends on a positive reward
    earns an extra 10. Episodes are capped at 1000 steps.

    Args:
        network: Object whose `forward(obs)` returns the raw action values.
        env: Environment exposing `reset(key)` and `step(state, action)`.
        n_episodes: Number of episodes to average over; must be positive.

    Returns:
        The average shaped reward per episode.

    Raises:
        ValueError: If `n_episodes` is not positive.
    """
    if n_episodes <= 0:
        raise ValueError(f"n_episodes must be positive, got {n_episodes}")

    # Generate a unique key for this evaluation (network identity + wall clock)
    timestamp = int(time.time() * 1000)
    master_key = random.PRNGKey(hash((id(network), timestamp)) % (2**32))

    # Loop-invariant: per-output firing thresholds for left, right, jump.
    thresholds = jnp.array([0.3, 0.3, 0.4])

    total_reward = 0.0
    for _ in range(n_episodes):
        # Reset environment with a leading batch dimension on the key
        master_key, reset_key = random.split(master_key)
        state = env.reset(reset_key[None, :])
        done = False
        episode_reward = 0.0
        steps = 0

        while not done and steps < 1000:  # step limit
            # Add a leading axis and scale the inputs down.
            # NOTE(review): this assumes state.obs is 1-D; if reset/step return
            # an already batched observation this adds a second leading axis.
            # Confirm against the environment.
            obs = state.obs[None, :] / 10.0

            raw_action = network.forward(obs)
            binary_action = _to_binary_action(raw_action, thresholds)

            # The environment step takes no key, so none is drawn here.
            next_state, reward, done = env.step(state, binary_action)

            # Collapse possibly batched reward/done to Python scalars
            if isinstance(reward, jnp.ndarray):
                reward = float(jnp.reshape(reward, (-1,))[0])
            if isinstance(done, jnp.ndarray):
                done = bool(jnp.reshape(done, (-1,))[0])

            # Small reward for movement to encourage exploration
            any_movement = jnp.any(binary_action[:, :2] > 0)
            movement_reward = 0.1 if bool(any_movement) else 0.0

            # Small shaping rewards read from the next observation.
            # NOTE(review): indices 1, 4 and 6 are taken to be ball height,
            # ball x and ball x-velocity. Verify against the environment's
            # observation layout.
            ball_height = float(jnp.reshape(next_state.obs[1], (-1,))[0]) if hasattr(next_state.obs, '__getitem__') else 0.0
            height_reward = 0.1 if ball_height > 0.5 else 0.0

            ball_x = float(jnp.reshape(next_state.obs[4], (-1,))[0])
            ball_vx = float(jnp.reshape(next_state.obs[6], (-1,))[0])
            position_reward = 0.2 if ball_x > 0 else 0.0
            velocity_reward = 0.1 if ball_vx > 0 else 0.0

            # Game outcome counts double; shaping bonuses are halved.
            step_reward = reward * 2.0
            bonus_reward = movement_reward + height_reward + position_reward + velocity_reward
            total_step_reward = step_reward + bonus_reward * 0.5

            episode_reward += total_step_reward
            state = next_state
            steps += 1

        # Bonus when the episode ended on a positive reward
        if done and reward > 0:
            episode_reward += 10.0

        total_reward += episode_reward

    return total_reward / n_episodes
292
+
293
def main():
    """Evolve a population of networks on SlimeVolley for 500 generations.

    Each generation evaluates every network, keeps the top few unchanged, and
    fills the rest of the population with mutated copies of tournament
    winners drawn from the top 20. Mutation strength is increased when the
    best fitness has not improved for more than 20 generations.
    """
    import copy  # local import: only this entry point needs it

    # Initialize environment
    env = SlimeVolley()

    # NEAT configuration
    config = {
        'population_size': 50,  # Smaller population for faster iteration
        'weight_mutation_rate': 0.8,
        'weight_mutation_power': 0.3,  # Increased for more exploration
        'add_node_rate': 0.3,
        'add_connection_rate': 0.5,
    }

    # Create initial population
    population = [
        Network(Genome(n_inputs=12, n_outputs=3))
        for _ in range(config['population_size'])
    ]

    best_fitness = float('-inf')
    generations_without_improvement = 0

    # Evolution loop
    for generation in range(500):
        print(f"\nGeneration {generation}")
        print("-" * 20)

        # Evaluate population
        fitnesses = []
        for i, net in enumerate(population):
            fitness = evaluate_network(net, env)
            fitnesses.append(fitness)
            print(f"Network {i}: Fitness = {fitness:.2f}")

            if fitness > best_fitness:
                best_fitness = fitness
                generations_without_improvement = 0
                print(f"New best fitness: {best_fitness:.2f}")

        # Check for improvement
        generations_without_improvement += 1
        if generations_without_improvement > 20:
            print("No improvement for 20 generations, increasing mutation rates")
            config['weight_mutation_rate'] = min(1.0, config['weight_mutation_rate'] * 1.2)
            config['weight_mutation_power'] = min(0.5, config['weight_mutation_power'] * 1.2)
            generations_without_improvement = 0

        # Print progress
        avg_fitness = sum(fitnesses) / len(fitnesses)
        print(f"\nBest fitness: {best_fitness:.2f}")
        print(f"Average fitness: {avg_fitness:.2f}")

        # Selection and reproduction
        new_population = []
        sorted_indices = np.argsort(fitnesses)[::-1]  # Best to worst

        # Keep best networks (clamped so a tiny population still works)
        n_elite = min(5, len(population))
        new_population.extend(population[i] for i in sorted_indices[:n_elite])
        print(f"Keeping top {n_elite} networks")

        # Parents are drawn from the top 20; the tournament cannot be larger
        # than that pool because sampling is without replacement.
        candidate_pool = sorted_indices[:20]
        tournament_size = min(5, len(candidate_pool))

        # Create offspring from best networks
        while len(new_population) < config['population_size']:
            # Tournament selection
            tournament = np.random.choice(candidate_pool, tournament_size, replace=False)
            parent_idx = tournament[np.argmax([fitnesses[i] for i in tournament])]
            parent = population[parent_idx]

            # Deep-copy the parent's genome. A shallow copy of the gene
            # containers would share the gene objects, so mutating the child
            # would also rewrite the parent (and any surviving elite).
            child_genome = copy.deepcopy(parent.genome)

            # Mutate child
            child_genome.mutate(config)

            # Add to new population
            new_population.append(Network(child_genome))

        population = new_population
        print(f"Created {len(population)} networks for next generation")
375
+
376
# Run the training loop only when this file is executed as a script.
if __name__ == "__main__":
    main()