Upload neat\evolution.py with huggingface_hub
Browse files- neat//evolution.py +302 -0
neat//evolution.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""NEAT evolution implementation."""
|
2 |
+
|
3 |
+
import jax
|
4 |
+
import jax.numpy as jnp
|
5 |
+
import numpy as np
|
6 |
+
from typing import List, Dict, Optional, Tuple, Callable
|
7 |
+
from .network import Network
|
8 |
+
from .genome import Genome
|
9 |
+
|
10 |
+
class NEATEvolution:
|
11 |
+
"""NEAT evolution implementation with structural mutations."""
|
12 |
+
|
13 |
+
DEFAULT_CONFIG = {
|
14 |
+
'node_add_prob': 0.2, # Standard node addition rate
|
15 |
+
'conn_add_prob': 0.3, # Standard connection addition rate
|
16 |
+
'weight_mutate_prob': 0.8, # High chance of weight mutation
|
17 |
+
'weight_replace_prob': 0.1, # Low chance of complete weight replacement
|
18 |
+
'weight_perturb_size': 0.5, # Standard weight perturbation size
|
19 |
+
'bias_mutate_prob': 0.8, # High chance of bias mutation
|
20 |
+
'bias_replace_prob': 0.1, # Low chance of complete bias replacement
|
21 |
+
'bias_perturb_size': 0.5, # Standard bias perturbation size
|
22 |
+
'complexity_coefficient': 0.0, # No complexity penalty
|
23 |
+
'species_distance': 2.0, # Standard species distance
|
24 |
+
'species_elitism': 2, # Keep top 2 from each species
|
25 |
+
'survival_threshold': 0.3 # Keep 30% of population
|
26 |
+
}
|
27 |
+
|
28 |
+
def __init__(self,
|
29 |
+
n_inputs: int,
|
30 |
+
n_outputs: int,
|
31 |
+
population_size: int,
|
32 |
+
config: Optional[Dict] = None,
|
33 |
+
key: Optional[jnp.ndarray] = None):
|
34 |
+
"""Initialize NEAT evolution.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
n_inputs: Number of input nodes (12 for volleyball)
|
38 |
+
n_outputs: Number of output nodes (3 for volleyball)
|
39 |
+
population_size: Size of population
|
40 |
+
config: Optional configuration parameters
|
41 |
+
key: Random key for JAX
|
42 |
+
"""
|
43 |
+
self.n_inputs = n_inputs
|
44 |
+
self.n_outputs = n_outputs
|
45 |
+
self.population_size = population_size
|
46 |
+
self.config = {**self.DEFAULT_CONFIG, **(config or {})}
|
47 |
+
|
48 |
+
# Initialize random key
|
49 |
+
if key is None:
|
50 |
+
self.key = jax.random.PRNGKey(0)
|
51 |
+
else:
|
52 |
+
self.key = key
|
53 |
+
|
54 |
+
# Initialize population
|
55 |
+
self.population = self._init_population()
|
56 |
+
self.generation = 0
|
57 |
+
self.innovation_number = 0
|
58 |
+
self.species = []
|
59 |
+
|
60 |
+
def _init_population(self) -> List[Genome]:
|
61 |
+
"""Initialize population with minimal networks."""
|
62 |
+
population = []
|
63 |
+
for _ in range(self.population_size):
|
64 |
+
# Split random key
|
65 |
+
self.key, subkey = jax.random.split(self.key)
|
66 |
+
|
67 |
+
# Create genome with proper input/output sizes
|
68 |
+
genome = Genome(self.n_inputs, self.n_outputs, subkey)
|
69 |
+
|
70 |
+
# Add random hidden nodes (between 2-6)
|
71 |
+
self.key, subkey = jax.random.split(self.key)
|
72 |
+
n_hidden = int(jax.random.randint(subkey, (), 2, 7))
|
73 |
+
|
74 |
+
hidden_nodes = []
|
75 |
+
for _ in range(n_hidden):
|
76 |
+
hidden_nodes.append(genome.add_node())
|
77 |
+
|
78 |
+
# Connect inputs to hidden with 50% probability
|
79 |
+
for i in range(self.n_inputs):
|
80 |
+
for h in hidden_nodes:
|
81 |
+
self.key, subkey = jax.random.split(self.key)
|
82 |
+
if jax.random.uniform(subkey) < 0.5:
|
83 |
+
self.key, subkey = jax.random.split(self.key)
|
84 |
+
weight = jax.random.normal(subkey) * 0.5
|
85 |
+
genome.add_connection(i, h, weight)
|
86 |
+
|
87 |
+
# Connect hidden to outputs with 50% probability
|
88 |
+
output_start = genome.n_nodes - self.n_outputs
|
89 |
+
for h in hidden_nodes:
|
90 |
+
for i in range(self.n_outputs):
|
91 |
+
self.key, subkey = jax.random.split(self.key)
|
92 |
+
if jax.random.uniform(subkey) < 0.5:
|
93 |
+
self.key, subkey = jax.random.split(self.key)
|
94 |
+
weight = jax.random.normal(subkey) * 0.5
|
95 |
+
genome.add_connection(h, output_start + i, weight)
|
96 |
+
|
97 |
+
# Add skip connections with 30% probability
|
98 |
+
for i in range(self.n_inputs):
|
99 |
+
for j in range(self.n_outputs):
|
100 |
+
self.key, subkey = jax.random.split(self.key)
|
101 |
+
if jax.random.uniform(subkey) < 0.3:
|
102 |
+
self.key, subkey = jax.random.split(self.key)
|
103 |
+
weight = jax.random.normal(subkey) * 0.3
|
104 |
+
genome.add_connection(i, output_start + j, weight)
|
105 |
+
|
106 |
+
population.append(genome)
|
107 |
+
return population
|
108 |
+
|
109 |
+
def ask(self) -> List[Network]:
|
110 |
+
"""Get current population as networks."""
|
111 |
+
return [Network(genome) for genome in self.population]
|
112 |
+
|
113 |
+
def tell(self, fitnesses: List[float]) -> None:
|
114 |
+
"""Update population based on fitness scores."""
|
115 |
+
# Sort population by fitness
|
116 |
+
sorted_pop = sorted(zip(self.population, fitnesses),
|
117 |
+
key=lambda x: x[1], reverse=True)
|
118 |
+
|
119 |
+
# For very small populations, keep at least one parent
|
120 |
+
n_parents = max(1, int(self.population_size * self.config['survival_threshold']))
|
121 |
+
parents = [p for p, _ in sorted_pop[:n_parents]]
|
122 |
+
|
123 |
+
# Ensure we have at least one parent
|
124 |
+
if not parents:
|
125 |
+
# If all fitnesses are equal (including all zeros), keep the first one
|
126 |
+
parents = [sorted_pop[0][0]]
|
127 |
+
|
128 |
+
# Create new population starting with the best performer
|
129 |
+
new_population = [parents[0]] # Always keep the best one
|
130 |
+
|
131 |
+
# Fill rest with mutated offspring
|
132 |
+
while len(new_population) < self.population_size:
|
133 |
+
# Select parent (with replacement)
|
134 |
+
parent = parents[0] if len(parents) == 1 else np.random.choice(parents)
|
135 |
+
child = parent.copy()
|
136 |
+
|
137 |
+
# Mutate child
|
138 |
+
child = self._mutate_genome(child, self.key)
|
139 |
+
|
140 |
+
new_population.append(child)
|
141 |
+
|
142 |
+
self.population = new_population
|
143 |
+
self.generation += 1
|
144 |
+
|
145 |
+
def _mutate_genome(self, genome: Genome, key: jnp.ndarray) -> Genome:
|
146 |
+
"""Mutate a genome.
|
147 |
+
|
148 |
+
Mutation types:
|
149 |
+
1. Add new nodes (30% chance)
|
150 |
+
2. Add new connections (50% chance)
|
151 |
+
3. Modify weights (80% chance)
|
152 |
+
4. Modify biases (70% chance)
|
153 |
+
5. Enable/disable connections (20% chance)
|
154 |
+
"""
|
155 |
+
# Split random key
|
156 |
+
keys = jax.random.split(key, 6)
|
157 |
+
|
158 |
+
# Add nodes
|
159 |
+
if jax.random.uniform(keys[0]) < self.config['node_add_prob']:
|
160 |
+
# Add 1-3 nodes with decreasing probability
|
161 |
+
n_nodes = 1
|
162 |
+
while jax.random.uniform(keys[1]) < 0.3 and n_nodes < 4:
|
163 |
+
# Pick random enabled connection
|
164 |
+
enabled_conns = [(src, dst) for (src, dst), enabled in genome.connections.items() if enabled]
|
165 |
+
if enabled_conns:
|
166 |
+
src, dst = enabled_conns[int(jax.random.randint(keys[2], (), 0, len(enabled_conns)))]
|
167 |
+
genome.add_node_between(src, dst)
|
168 |
+
n_nodes += 1
|
169 |
+
|
170 |
+
# Add connections
|
171 |
+
if jax.random.uniform(keys[1]) < self.config['conn_add_prob']:
|
172 |
+
# Add multiple connections with decreasing probability
|
173 |
+
n_conns = 0
|
174 |
+
max_attempts = 20 # Prevent infinite loops
|
175 |
+
attempts = 0
|
176 |
+
|
177 |
+
while attempts < max_attempts and n_conns < 5:
|
178 |
+
# Pick random nodes
|
179 |
+
src = int(jax.random.randint(keys[2], (), 0, genome.n_nodes))
|
180 |
+
dst = int(jax.random.randint(keys[3], (), 0, genome.n_nodes))
|
181 |
+
|
182 |
+
# Add connection if valid and not already present
|
183 |
+
if src != dst and (src, dst) not in genome.connections:
|
184 |
+
weight = jax.random.normal(keys[4]) * 0.5
|
185 |
+
genome.add_connection(src, dst, weight)
|
186 |
+
n_conns += 1
|
187 |
+
attempts += 1
|
188 |
+
|
189 |
+
# Mutate weights
|
190 |
+
if jax.random.uniform(keys[2]) < self.config['weight_mutate_prob']:
|
191 |
+
for conn in list(genome.connections.keys()):
|
192 |
+
if genome.connections[conn]: # Only mutate enabled connections
|
193 |
+
if jax.random.uniform(keys[3]) < self.config['weight_replace_prob']:
|
194 |
+
# Reset weight
|
195 |
+
genome.weights[conn] = jax.random.normal(keys[4]) * self.config['weight_perturb_size']
|
196 |
+
else:
|
197 |
+
# Perturb weight
|
198 |
+
genome.weights[conn] += jax.random.normal(keys[4]) * self.config['weight_perturb_size']
|
199 |
+
|
200 |
+
# Mutate biases
|
201 |
+
if jax.random.uniform(keys[3]) < self.config['bias_mutate_prob']:
|
202 |
+
for node in list(genome.biases.keys()):
|
203 |
+
if jax.random.uniform(keys[4]) < self.config['bias_replace_prob']:
|
204 |
+
# Reset bias
|
205 |
+
genome.biases[node] = jax.random.normal(keys[5]) * self.config['bias_perturb_size']
|
206 |
+
else:
|
207 |
+
# Perturb bias
|
208 |
+
genome.biases[node] += jax.random.normal(keys[5]) * self.config['bias_perturb_size']
|
209 |
+
|
210 |
+
# Enable/disable connections
|
211 |
+
for conn in list(genome.connections.keys()):
|
212 |
+
if jax.random.uniform(keys[5]) < 0.2: # 20% chance per connection
|
213 |
+
genome.connections[conn] = not genome.connections[conn]
|
214 |
+
|
215 |
+
return genome
|
216 |
+
|
217 |
+
def get_average_nodes(self) -> float:
|
218 |
+
"""Get average number of nodes in population."""
|
219 |
+
return np.mean([g.n_nodes for g in self.population])
|
220 |
+
|
221 |
+
def get_average_connections(self) -> float:
|
222 |
+
"""Get average number of connections in population."""
|
223 |
+
return np.mean([len(g.connections) for g in self.population])
|
224 |
+
|
225 |
+
def get_activation_distribution(self) -> Dict[str, float]:
|
226 |
+
"""Get distribution of activation functions in population.
|
227 |
+
|
228 |
+
Returns:
|
229 |
+
Dictionary mapping activation function names to their frequency
|
230 |
+
"""
|
231 |
+
# For now we only use ReLU
|
232 |
+
return {'relu': 1.0}
|
233 |
+
|
234 |
+
def run_evolution(self, evaluator: Callable[[Network], float], max_generations: int,
|
235 |
+
fitness_threshold: float, reset_mutations: bool = True,
|
236 |
+
max_stagnation: int = 15, verbose: bool = True) -> Tuple[Network, float]:
|
237 |
+
"""Run the evolution process
|
238 |
+
|
239 |
+
Args:
|
240 |
+
evaluator: Function that takes a network and returns its fitness
|
241 |
+
max_generations: Maximum number of generations to run
|
242 |
+
fitness_threshold: Target fitness to achieve
|
243 |
+
reset_mutations: Whether to reset mutations when fitness improves
|
244 |
+
max_stagnation: Maximum generations without improvement before stopping
|
245 |
+
verbose: Whether to print progress
|
246 |
+
|
247 |
+
Returns:
|
248 |
+
Tuple of (best network, best fitness)
|
249 |
+
"""
|
250 |
+
best_fitness = float('-inf')
|
251 |
+
best_network = None
|
252 |
+
stagnation_counter = 0
|
253 |
+
|
254 |
+
for generation in range(max_generations):
|
255 |
+
# Evaluate current population
|
256 |
+
fitnesses = []
|
257 |
+
for genome in self.population:
|
258 |
+
network = genome.to_network()
|
259 |
+
fitness = evaluator(network)
|
260 |
+
genome.fitness = fitness
|
261 |
+
fitnesses.append(fitness)
|
262 |
+
|
263 |
+
# Update best if improved
|
264 |
+
if fitness > best_fitness:
|
265 |
+
best_fitness = fitness
|
266 |
+
best_network = network
|
267 |
+
stagnation_counter = 0
|
268 |
+
if reset_mutations:
|
269 |
+
self.reset_innovation()
|
270 |
+
|
271 |
+
# Get statistics
|
272 |
+
avg_fitness = sum(fitnesses) / len(fitnesses)
|
273 |
+
generation_best = max(fitnesses)
|
274 |
+
|
275 |
+
# Print progress
|
276 |
+
if verbose:
|
277 |
+
print(f"\nGeneration {generation}:")
|
278 |
+
print(f" Best Fitness: {best_fitness:.2f}")
|
279 |
+
print(f" Generation Best: {generation_best:.2f}")
|
280 |
+
print(f" Average Nodes: {self.get_average_nodes():.1f}")
|
281 |
+
print(f" Average Connections: {self.get_average_connections():.1f}")
|
282 |
+
|
283 |
+
# Check for improvement
|
284 |
+
if generation_best <= best_fitness:
|
285 |
+
stagnation_counter += 1
|
286 |
+
else:
|
287 |
+
stagnation_counter = 0
|
288 |
+
|
289 |
+
# Create next generation
|
290 |
+
self.tell(fitnesses)
|
291 |
+
|
292 |
+
# Stop if stagnated too long
|
293 |
+
if stagnation_counter >= max_stagnation:
|
294 |
+
if verbose:
|
295 |
+
print(f"\nStopping: No improvement for {max_stagnation} generations")
|
296 |
+
break
|
297 |
+
|
298 |
+
if verbose:
|
299 |
+
print("\nTraining complete!")
|
300 |
+
print(f"Best fitness achieved: {best_fitness:.2f}")
|
301 |
+
|
302 |
+
return best_network, best_fitness
|