"""BackpropNEAT implementation."""

import jax
import jax.numpy as jnp
import numpy as np
from typing import List

from .network import Network
from .genome import Genome

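# BackpropNEAT couples NEAT-style population search over network topologies
# with gradient-based weight training: each generation, every network is
# trained by SGD with momentum (_make_train_step), its best accuracy becomes
# its evolutionary fitness (_compute_fitness), and the population is then
# refreshed via elitism, tournament selection, and mutation (evolve).
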
class BackpropNEAT:
    """Backpropagation-based NEAT implementation."""
    
    def __init__(self, population_size=5, n_inputs=2, n_outputs=1, n_hidden=64,
                 learning_rate=0.01, beta=0.9):
        """Initialize BackpropNEAT."""
        self.population_size = population_size
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        self.n_hidden = n_hidden
        self.learning_rate = learning_rate
        self.beta = beta
        
        # Track the best network found across generations (used by evolve)
        self.best_fitness = float('-inf')
        self.best_network = None
        
        # Initialize population
        self.population = []
        self.momentum_buffers = []
        
        for _ in range(population_size):
            # Create genome with skip connections
            genome = Genome(n_inputs, n_outputs, n_hidden)
            genome.add_layer_connections()  # Add standard layer connections
            genome.add_skip_connections(0.3)  # Add skip connections with 30% probability
            
            # Create network from genome
            network = Network(genome)
            self.population.append(network)
            
            # Initialize momentum buffer for this network
            momentum = {
                'weights': {k: jnp.zeros_like(w) for k, w in network.params['weights'].items()},
                'biases': jnp.zeros_like(network.params['biases']),
                'gamma': jnp.zeros_like(network.params['gamma']),
                'beta': jnp.zeros_like(network.params['beta'])
            }
            self.momentum_buffers.append(momentum)
        
        # Create train step function
        self._train_step = self._make_train_step()
        
        # Bind a per-network train step. self._train_step is a plain function
        # stored as an instance attribute (not a bound method), so self is
        # passed explicitly; the idx=i default freezes the loop index at
        # definition time, avoiding Python's late-binding closure pitfall.
        for i, network in enumerate(self.population):
            network.population_idx = i
            network._train_step = lambda p, x, y, idx=i: self._train_step(self, p, x, y, idx)
    
    def forward(self, params, x):
        """Forward pass through a network.
        
        All population members share one forward implementation: the genome's
        connectivity is assumed to be encoded in `params`, so
        population[0].forward can evaluate any member's parameters.
        """
        return self.population[0].forward(params, x)
    
    def _make_train_step(self):
        """Create training step function."""
        # Constants for numerical stability
        eps = 1e-7
        min_lr = 1e-6
        max_lr = 1e-2
        
        def loss_fn(params, x, y):
            """Compute loss for parameters."""
            logits = self.forward(params, x)
            
            # Binary cross entropy with label smoothing. Labels are assumed
            # to live in {-1, +1} (matching the tanh-style logits), so
            # smoothing toward probability 0.5 means shrinking toward 0:
            #   p_smooth = (1 - alpha) * 0.5 * (1 + y) + alpha * 0.5
            #            = 0.5 * (1 + (1 - alpha) * y)
            alpha = 0.1  # Label smoothing factor
            y_smooth = (1 - alpha) * y
            
            # Convert logits to probabilities
            probs = 0.5 * (logits + 1)  # Map from [-1, 1] to [0, 1]
            probs = jnp.clip(probs, eps, 1 - eps)
            
            # Smoothed BCE; 0.5 * (1 + y_smooth) is the target probability
            bce_loss = -jnp.mean(
                0.5 * (1 + y_smooth) * jnp.log(probs) +
                0.5 * (1 - y_smooth) * jnp.log(1 - probs)
            )
            
            # L2 regularization with a very small coefficient
            l2_reg = sum(jnp.sum(w ** 2) for w in params['weights'].values())
            return bce_loss + 1e-6 * l2_reg
        
        @jax.jit
        def compute_updates(params, x, y):
            """Compute gradients and loss."""
            loss_value, grads = jax.value_and_grad(loss_fn)(params, x, y)
            return grads, loss_value
        
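        # Note: jax.jit retraces compute_updates whenever the input shapes
        # change, so the smaller final batch of an epoch costs one extra
        # compilation; padding batches to a fixed size would avoid this.
        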
        def train_step(self, params, x, y, network_idx):
            """Perform single training step with momentum."""
            # Compute gradients
            grads, loss_value = compute_updates(params, x, y)
            
            # Get momentum buffer for this network
            momentum = self.momentum_buffers[network_idx]
            
            # Gradient norm for adaptive learning rate
            grad_norm = jnp.sqrt(
                sum(jnp.sum(g ** 2) for g in grads['weights'].values()) +
                jnp.sum(grads['biases'] ** 2) +
                jnp.sum(grads['gamma'] ** 2) +
                jnp.sum(grads['beta'] ** 2) +
                eps  # Add eps for numerical stability
            )
            
            # Adaptive learning rate: scale down by the norm when gradients
            # are large; damp logarithmically when they are small. Note that
            # 1 + log(grad_norm) goes negative for very small norms, so the
            # clip below also acts as a floor at min_lr.
            if grad_norm > 1.0:
                effective_lr = self.learning_rate / grad_norm
            else:
                effective_lr = self.learning_rate * (1.0 + jnp.log(grad_norm + eps))
            
            # Clip learning rate to a safe range
            effective_lr = jnp.clip(effective_lr, min_lr, max_lr)
            
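            # Shared update rule (EMA momentum on clipped gradients):
            #   v <- beta * v + (1 - beta) * clip(g)
            #   p <- p - effective_lr * v   (plus weight decay for weights)
            # The layer-norm parameters use a tighter clip and 0.1x lr below.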
            # Update weights momentum with adaptive learning rate
            new_weights = {}
            for k in params['weights'].keys():
                grad = grads['weights'][k]
                
                # Update momentum with gradient clipping
                momentum['weights'][k] = (
                    self.beta * momentum['weights'][k] +
                    (1 - self.beta) * jnp.clip(grad, -1.0, 1.0)
                )
                
                # Apply update with weight decay
                weight_decay = 0.0001 * params['weights'][k]
                new_weights[k] = params['weights'][k] - effective_lr * (
                    momentum['weights'][k] + weight_decay
                )
            
            # Update biases momentum
            momentum['biases'] = (
                self.beta * momentum['biases'] +
                (1 - self.beta) * jnp.clip(grads['biases'], -1.0, 1.0)
            )
            new_biases = params['biases'] - effective_lr * momentum['biases']
            
            # Update layer norm parameters with smaller learning rate
            ln_lr = 0.1 * effective_lr  # Slower updates for stability
            
            # Gamma (scale)
            momentum['gamma'] = (
                self.beta * momentum['gamma'] +
                (1 - self.beta) * jnp.clip(grads['gamma'], -0.1, 0.1)
            )
            new_gamma = params['gamma'] - ln_lr * momentum['gamma']
            new_gamma = jnp.clip(new_gamma, 0.1, 10.0)  # Prevent collapse
            
            # Beta (shift)
            momentum['beta'] = (
                self.beta * momentum['beta'] +
                (1 - self.beta) * jnp.clip(grads['beta'], -0.1, 0.1)
            )
            new_beta = params['beta'] - ln_lr * momentum['beta']
            
            return {
                'weights': new_weights,
                'biases': new_biases,
                'gamma': new_gamma,
                'beta': new_beta
            }, loss_value
        
        return train_step
    
    def _mutate_genome(self, genome: Genome) -> Genome:
        """Mutate genome parameters (weights and biases).
        
        Note: this only perturbs existing parameters; it does not add
        nodes or connections.
        """
        new_genome = genome.copy()
        
        # Perturb weights with element-wise Gaussian noise
        for key, w in new_genome.params['weights'].items():
            if np.random.random() < 0.1:
                new_genome.params['weights'][key] = w + np.random.normal(0, 0.2, size=np.shape(w))
        
        # Perturb biases the same way
        for key, b in new_genome.params['biases'].items():
            if np.random.random() < 0.1:
                new_genome.params['biases'][key] = b + np.random.normal(0, 0.2, size=np.shape(b))
        
        return new_genome
    
    def _select_parent(self, fitnesses: List[float]) -> int:
        """Select parent using tournament selection."""
        # Tournament selection
        tournament_size = 3
        best_idx = np.random.randint(len(fitnesses))
        best_fitness = fitnesses[best_idx]
        
        for _ in range(tournament_size - 1):
            idx = np.random.randint(len(fitnesses))
            if fitnesses[idx] > best_fitness:
                best_idx = idx
                best_fitness = fitnesses[idx]
        
        return best_idx
    
    def _compute_fitness(self, network: Network, x: jnp.ndarray, y: jnp.ndarray,
                        n_epochs: int = 100, batch_size: int = 32) -> float:
        """Compute fitness of network."""
        n_samples = x.shape[0]
        best_loss = float('inf')
        best_accuracy = 0.0
        
        # Initial prediction
        initial_pred = network.predict(x)
        initial_acc = float(jnp.mean((initial_pred == y)))
        
        # Train network
        no_improve = 0
        for epoch in range(n_epochs):
            # Shuffle data
            perm = np.random.permutation(n_samples)
            x_shuffled = x[perm]
            y_shuffled = y[perm]
            
            # Train in batches
            epoch_losses = []
            for i in range(0, n_samples, batch_size):
                batch_x = x_shuffled[i:i + batch_size]  # slicing clamps at the array end
                batch_y = y_shuffled[i:i + batch_size]
                
                # Train step
                network.params, loss = network._train_step(network.params, batch_x, batch_y)
                epoch_losses.append(float(loss))
            
            # Update best loss
            avg_loss = float(np.mean(epoch_losses))
            if avg_loss < best_loss:
                best_loss = avg_loss
                no_improve = 0
            else:
                no_improve += 1
            
            # Compute accuracy
            predictions = network.predict(x)
            accuracy = float(jnp.mean((predictions == y)))
            best_accuracy = max(best_accuracy, accuracy)
            
            # Print progress every 10 epochs
            if epoch % 10 == 0:
                print(f"Epoch {epoch}: Loss = {avg_loss:.4f}, Accuracy = {accuracy:.4f}")
            
            # Early stopping if good accuracy or no improvement
            if accuracy > 0.95 or no_improve >= 10:
                print(f"Early stopping at epoch {epoch}")
                print(f"Final accuracy: {accuracy:.4f}")
                break
        
        # Print improvement
        print(f"Network improved from {initial_acc:.4f} to {best_accuracy:.4f}")
        
        # Fitness based on accuracy
        fitness = best_accuracy
        
        return float(fitness)
    
    def evolve(self, x: jnp.ndarray, y: jnp.ndarray, n_generations: int = 50) -> Network:
        """Evolve network architectures."""
        for generation in range(n_generations):
            print(f"\nGeneration {generation}")
            
            # Evaluate current population
            fitnesses = []
            for network in self.population:
                fitness = self._compute_fitness(network, x, y)
                fitnesses.append(fitness)
                
                # Update best network
                if fitness > self.best_fitness:
                    self.best_fitness = fitness
                    self.best_network = Network(network.genome.copy())
                    print(f"New best fitness: {fitness:.4f}")
            
            # Create new population through selection and mutation
            new_population = []
            
            # Keep best network (elitism)
            best_idx = np.argmax(fitnesses)
            new_population.append(Network(self.population[best_idx].genome.copy()))
            
            # Create rest of population
            while len(new_population) < self.population_size:
                # Select parent
                parent_idx = self._select_parent(fitnesses)
                parent = self.population[parent_idx].genome
                
                # Create child through mutation
                child_genome = self._mutate_genome(parent)
                new_population.append(Network(child_genome))
            
            self.population = new_population
            
            # Fresh Network objects have no bound train step or momentum
            # buffers, so mirror the setup done in __init__
            self.momentum_buffers = [
                {'weights': {k: jnp.zeros_like(w) for k, w in net.params['weights'].items()},
                 'biases': jnp.zeros_like(net.params['biases']),
                 'gamma': jnp.zeros_like(net.params['gamma']),
                 'beta': jnp.zeros_like(net.params['beta'])}
                for net in self.population
            ]
            for i, net in enumerate(self.population):
                net.population_idx = i
                net._train_step = lambda p, x, y, idx=i: self._train_step(self, p, x, y, idx)
        
        return self.best_network
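
# Minimal usage sketch (illustrative; not part of the original module). It
# assumes labels in {-1, +1} to match the loss in _make_train_step, and that
# Network/Genome expose the interfaces used above.
if __name__ == "__main__":
    xor_x = jnp.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
    xor_y = jnp.array([-1.0, 1.0, 1.0, -1.0])
    
    neat = BackpropNEAT(population_size=5, n_inputs=2, n_outputs=1, n_hidden=16)
    best_network = neat.evolve(xor_x, xor_y, n_generations=5)
    print(f"Best fitness: {neat.best_fitness:.4f}")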