eyad-silx commited on
Commit
41ff170
·
verified ·
1 Parent(s): 80f8293

Upload neat\evolution.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. neat//evolution.py +302 -0
neat//evolution.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """NEAT evolution implementation."""
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ import numpy as np
6
+ from typing import List, Dict, Optional, Tuple, Callable
7
+ from .network import Network
8
+ from .genome import Genome
9
+
10
+ class NEATEvolution:
11
+ """NEAT evolution implementation with structural mutations."""
12
+
13
+ DEFAULT_CONFIG = {
14
+ 'node_add_prob': 0.2, # Standard node addition rate
15
+ 'conn_add_prob': 0.3, # Standard connection addition rate
16
+ 'weight_mutate_prob': 0.8, # High chance of weight mutation
17
+ 'weight_replace_prob': 0.1, # Low chance of complete weight replacement
18
+ 'weight_perturb_size': 0.5, # Standard weight perturbation size
19
+ 'bias_mutate_prob': 0.8, # High chance of bias mutation
20
+ 'bias_replace_prob': 0.1, # Low chance of complete bias replacement
21
+ 'bias_perturb_size': 0.5, # Standard bias perturbation size
22
+ 'complexity_coefficient': 0.0, # No complexity penalty
23
+ 'species_distance': 2.0, # Standard species distance
24
+ 'species_elitism': 2, # Keep top 2 from each species
25
+ 'survival_threshold': 0.3 # Keep 30% of population
26
+ }
27
+
28
+ def __init__(self,
29
+ n_inputs: int,
30
+ n_outputs: int,
31
+ population_size: int,
32
+ config: Optional[Dict] = None,
33
+ key: Optional[jnp.ndarray] = None):
34
+ """Initialize NEAT evolution.
35
+
36
+ Args:
37
+ n_inputs: Number of input nodes (12 for volleyball)
38
+ n_outputs: Number of output nodes (3 for volleyball)
39
+ population_size: Size of population
40
+ config: Optional configuration parameters
41
+ key: Random key for JAX
42
+ """
43
+ self.n_inputs = n_inputs
44
+ self.n_outputs = n_outputs
45
+ self.population_size = population_size
46
+ self.config = {**self.DEFAULT_CONFIG, **(config or {})}
47
+
48
+ # Initialize random key
49
+ if key is None:
50
+ self.key = jax.random.PRNGKey(0)
51
+ else:
52
+ self.key = key
53
+
54
+ # Initialize population
55
+ self.population = self._init_population()
56
+ self.generation = 0
57
+ self.innovation_number = 0
58
+ self.species = []
59
+
60
+ def _init_population(self) -> List[Genome]:
61
+ """Initialize population with minimal networks."""
62
+ population = []
63
+ for _ in range(self.population_size):
64
+ # Split random key
65
+ self.key, subkey = jax.random.split(self.key)
66
+
67
+ # Create genome with proper input/output sizes
68
+ genome = Genome(self.n_inputs, self.n_outputs, subkey)
69
+
70
+ # Add random hidden nodes (between 2-6)
71
+ self.key, subkey = jax.random.split(self.key)
72
+ n_hidden = int(jax.random.randint(subkey, (), 2, 7))
73
+
74
+ hidden_nodes = []
75
+ for _ in range(n_hidden):
76
+ hidden_nodes.append(genome.add_node())
77
+
78
+ # Connect inputs to hidden with 50% probability
79
+ for i in range(self.n_inputs):
80
+ for h in hidden_nodes:
81
+ self.key, subkey = jax.random.split(self.key)
82
+ if jax.random.uniform(subkey) < 0.5:
83
+ self.key, subkey = jax.random.split(self.key)
84
+ weight = jax.random.normal(subkey) * 0.5
85
+ genome.add_connection(i, h, weight)
86
+
87
+ # Connect hidden to outputs with 50% probability
88
+ output_start = genome.n_nodes - self.n_outputs
89
+ for h in hidden_nodes:
90
+ for i in range(self.n_outputs):
91
+ self.key, subkey = jax.random.split(self.key)
92
+ if jax.random.uniform(subkey) < 0.5:
93
+ self.key, subkey = jax.random.split(self.key)
94
+ weight = jax.random.normal(subkey) * 0.5
95
+ genome.add_connection(h, output_start + i, weight)
96
+
97
+ # Add skip connections with 30% probability
98
+ for i in range(self.n_inputs):
99
+ for j in range(self.n_outputs):
100
+ self.key, subkey = jax.random.split(self.key)
101
+ if jax.random.uniform(subkey) < 0.3:
102
+ self.key, subkey = jax.random.split(self.key)
103
+ weight = jax.random.normal(subkey) * 0.3
104
+ genome.add_connection(i, output_start + j, weight)
105
+
106
+ population.append(genome)
107
+ return population
108
+
109
+ def ask(self) -> List[Network]:
110
+ """Get current population as networks."""
111
+ return [Network(genome) for genome in self.population]
112
+
113
+ def tell(self, fitnesses: List[float]) -> None:
114
+ """Update population based on fitness scores."""
115
+ # Sort population by fitness
116
+ sorted_pop = sorted(zip(self.population, fitnesses),
117
+ key=lambda x: x[1], reverse=True)
118
+
119
+ # For very small populations, keep at least one parent
120
+ n_parents = max(1, int(self.population_size * self.config['survival_threshold']))
121
+ parents = [p for p, _ in sorted_pop[:n_parents]]
122
+
123
+ # Ensure we have at least one parent
124
+ if not parents:
125
+ # If all fitnesses are equal (including all zeros), keep the first one
126
+ parents = [sorted_pop[0][0]]
127
+
128
+ # Create new population starting with the best performer
129
+ new_population = [parents[0]] # Always keep the best one
130
+
131
+ # Fill rest with mutated offspring
132
+ while len(new_population) < self.population_size:
133
+ # Select parent (with replacement)
134
+ parent = parents[0] if len(parents) == 1 else np.random.choice(parents)
135
+ child = parent.copy()
136
+
137
+ # Mutate child
138
+ child = self._mutate_genome(child, self.key)
139
+
140
+ new_population.append(child)
141
+
142
+ self.population = new_population
143
+ self.generation += 1
144
+
145
+ def _mutate_genome(self, genome: Genome, key: jnp.ndarray) -> Genome:
146
+ """Mutate a genome.
147
+
148
+ Mutation types:
149
+ 1. Add new nodes (30% chance)
150
+ 2. Add new connections (50% chance)
151
+ 3. Modify weights (80% chance)
152
+ 4. Modify biases (70% chance)
153
+ 5. Enable/disable connections (20% chance)
154
+ """
155
+ # Split random key
156
+ keys = jax.random.split(key, 6)
157
+
158
+ # Add nodes
159
+ if jax.random.uniform(keys[0]) < self.config['node_add_prob']:
160
+ # Add 1-3 nodes with decreasing probability
161
+ n_nodes = 1
162
+ while jax.random.uniform(keys[1]) < 0.3 and n_nodes < 4:
163
+ # Pick random enabled connection
164
+ enabled_conns = [(src, dst) for (src, dst), enabled in genome.connections.items() if enabled]
165
+ if enabled_conns:
166
+ src, dst = enabled_conns[int(jax.random.randint(keys[2], (), 0, len(enabled_conns)))]
167
+ genome.add_node_between(src, dst)
168
+ n_nodes += 1
169
+
170
+ # Add connections
171
+ if jax.random.uniform(keys[1]) < self.config['conn_add_prob']:
172
+ # Add multiple connections with decreasing probability
173
+ n_conns = 0
174
+ max_attempts = 20 # Prevent infinite loops
175
+ attempts = 0
176
+
177
+ while attempts < max_attempts and n_conns < 5:
178
+ # Pick random nodes
179
+ src = int(jax.random.randint(keys[2], (), 0, genome.n_nodes))
180
+ dst = int(jax.random.randint(keys[3], (), 0, genome.n_nodes))
181
+
182
+ # Add connection if valid and not already present
183
+ if src != dst and (src, dst) not in genome.connections:
184
+ weight = jax.random.normal(keys[4]) * 0.5
185
+ genome.add_connection(src, dst, weight)
186
+ n_conns += 1
187
+ attempts += 1
188
+
189
+ # Mutate weights
190
+ if jax.random.uniform(keys[2]) < self.config['weight_mutate_prob']:
191
+ for conn in list(genome.connections.keys()):
192
+ if genome.connections[conn]: # Only mutate enabled connections
193
+ if jax.random.uniform(keys[3]) < self.config['weight_replace_prob']:
194
+ # Reset weight
195
+ genome.weights[conn] = jax.random.normal(keys[4]) * self.config['weight_perturb_size']
196
+ else:
197
+ # Perturb weight
198
+ genome.weights[conn] += jax.random.normal(keys[4]) * self.config['weight_perturb_size']
199
+
200
+ # Mutate biases
201
+ if jax.random.uniform(keys[3]) < self.config['bias_mutate_prob']:
202
+ for node in list(genome.biases.keys()):
203
+ if jax.random.uniform(keys[4]) < self.config['bias_replace_prob']:
204
+ # Reset bias
205
+ genome.biases[node] = jax.random.normal(keys[5]) * self.config['bias_perturb_size']
206
+ else:
207
+ # Perturb bias
208
+ genome.biases[node] += jax.random.normal(keys[5]) * self.config['bias_perturb_size']
209
+
210
+ # Enable/disable connections
211
+ for conn in list(genome.connections.keys()):
212
+ if jax.random.uniform(keys[5]) < 0.2: # 20% chance per connection
213
+ genome.connections[conn] = not genome.connections[conn]
214
+
215
+ return genome
216
+
217
+ def get_average_nodes(self) -> float:
218
+ """Get average number of nodes in population."""
219
+ return np.mean([g.n_nodes for g in self.population])
220
+
221
+ def get_average_connections(self) -> float:
222
+ """Get average number of connections in population."""
223
+ return np.mean([len(g.connections) for g in self.population])
224
+
225
+ def get_activation_distribution(self) -> Dict[str, float]:
226
+ """Get distribution of activation functions in population.
227
+
228
+ Returns:
229
+ Dictionary mapping activation function names to their frequency
230
+ """
231
+ # For now we only use ReLU
232
+ return {'relu': 1.0}
233
+
234
+ def run_evolution(self, evaluator: Callable[[Network], float], max_generations: int,
235
+ fitness_threshold: float, reset_mutations: bool = True,
236
+ max_stagnation: int = 15, verbose: bool = True) -> Tuple[Network, float]:
237
+ """Run the evolution process
238
+
239
+ Args:
240
+ evaluator: Function that takes a network and returns its fitness
241
+ max_generations: Maximum number of generations to run
242
+ fitness_threshold: Target fitness to achieve
243
+ reset_mutations: Whether to reset mutations when fitness improves
244
+ max_stagnation: Maximum generations without improvement before stopping
245
+ verbose: Whether to print progress
246
+
247
+ Returns:
248
+ Tuple of (best network, best fitness)
249
+ """
250
+ best_fitness = float('-inf')
251
+ best_network = None
252
+ stagnation_counter = 0
253
+
254
+ for generation in range(max_generations):
255
+ # Evaluate current population
256
+ fitnesses = []
257
+ for genome in self.population:
258
+ network = genome.to_network()
259
+ fitness = evaluator(network)
260
+ genome.fitness = fitness
261
+ fitnesses.append(fitness)
262
+
263
+ # Update best if improved
264
+ if fitness > best_fitness:
265
+ best_fitness = fitness
266
+ best_network = network
267
+ stagnation_counter = 0
268
+ if reset_mutations:
269
+ self.reset_innovation()
270
+
271
+ # Get statistics
272
+ avg_fitness = sum(fitnesses) / len(fitnesses)
273
+ generation_best = max(fitnesses)
274
+
275
+ # Print progress
276
+ if verbose:
277
+ print(f"\nGeneration {generation}:")
278
+ print(f" Best Fitness: {best_fitness:.2f}")
279
+ print(f" Generation Best: {generation_best:.2f}")
280
+ print(f" Average Nodes: {self.get_average_nodes():.1f}")
281
+ print(f" Average Connections: {self.get_average_connections():.1f}")
282
+
283
+ # Check for improvement
284
+ if generation_best <= best_fitness:
285
+ stagnation_counter += 1
286
+ else:
287
+ stagnation_counter = 0
288
+
289
+ # Create next generation
290
+ self.tell(fitnesses)
291
+
292
+ # Stop if stagnated too long
293
+ if stagnation_counter >= max_stagnation:
294
+ if verbose:
295
+ print(f"\nStopping: No improvement for {max_stagnation} generations")
296
+ break
297
+
298
+ if verbose:
299
+ print("\nTraining complete!")
300
+ print(f"Best fitness achieved: {best_fitness:.2f}")
301
+
302
+ return best_network, best_fitness