eyad-silx committed on
Commit
61162bb
·
verified ·
1 Parent(s): 3604754

Upload backprop_train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. backprop_train.py +360 -0
backprop_train.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Train BackpropNEAT on Spiral dataset."""
2
+
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import jax.numpy as jnp
6
+ import jax
7
+ import os
8
+ import json
9
+ from datetime import datetime
10
+ from sklearn.model_selection import train_test_split
11
+ from sklearn.utils import shuffle
12
+
13
+ from neat.backprop_neat import BackpropNEAT
14
+ from neat.datasets import generate_spiral_dataset
15
+ from neat.network import Network
16
+ from neat.genome import Genome
17
+
18
class NetworkLogger:
    """Record per-epoch snapshots of a network and persist them as JSON.

    Each call to ``log_network`` appends one snapshot to ``self.history`` and
    rewrites ``evolution_log.json`` inside ``output_dir`` with the full
    history, so the file on disk always reflects everything logged so far.
    """

    def __init__(self, output_dir: str):
        """Create a logger that writes into ``output_dir``.

        Args:
            output_dir: Directory that will hold ``evolution_log.json``. It is
                created (including parents) when it does not exist yet.
        """
        self.output_dir = output_dir
        # Without this the first log_network() call fails with
        # FileNotFoundError on a fresh run directory. An empty string means
        # "current directory", which os.makedirs would reject.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        self.log_file = os.path.join(output_dir, "evolution_log.json")
        self.history = []

    def log_network(self, epoch: int, network, loss: float, accuracy: float):
        """Append a snapshot of ``network`` and rewrite the log file.

        Args:
            epoch: Epoch index stored with the snapshot.
            network: Object exposing ``genome`` with ``n_nodes``,
                ``connections``, ``input_size`` and ``output_size``.
            loss: Loss value; converted with ``float``.
            accuracy: Accuracy value; converted with ``float``.
        """
        genome = network.genome
        network_state = {
            'epoch': epoch,
            'loss': float(loss),
            'accuracy': float(accuracy),
            # Normalised to a builtin int like the connection fields below, so
            # a non-builtin integer type cannot break json.dump.
            'n_nodes': int(genome.n_nodes),
            'n_connections': len(genome.connections),
            'complexity_score': float(self.calculate_complexity(network)),
            'structure': self.get_network_structure(network),
            'timestamp': datetime.now().isoformat()
        }
        self.history.append(network_state)

        # The whole history is rewritten on every call.
        with open(self.log_file, 'w', encoding='utf-8') as f:
            json.dump(self.history, f, indent=2)

    def calculate_complexity(self, network):
        """Return ``0.5 * n_nodes + n_connections`` for ``network``."""
        n_nodes = network.genome.n_nodes
        n_connections = len(network.genome.connections)
        return n_nodes * 0.5 + n_connections

    def get_network_structure(self, network):
        """Return a JSON-friendly description of the network topology.

        The result holds ``input_size``, ``output_size``, ``hidden_nodes``
        (``n_nodes`` minus inputs and outputs) and ``connections``, a list of
        ``{'source', 'target', 'weight'}`` dicts built from the
        ``(src, dst) -> weight`` items of ``genome.connections``.
        """
        genome = network.genome
        connections = [
            {
                'source': int(src),
                'target': int(dst),
                'weight': float(weight)
            }
            for (src, dst), weight in genome.connections.items()
        ]
        return {
            'input_size': int(genome.input_size),
            'output_size': int(genome.output_size),
            'hidden_nodes': int(genome.n_nodes - genome.input_size - genome.output_size),
            'connections': connections
        }

    def plot_evolution(self, save_path: str):
        """Plot accuracy, complexity score and loss per epoch to ``save_path``."""
        epochs = [log['epoch'] for log in self.history]
        accuracies = [log['accuracy'] for log in self.history]
        complexities = [log['complexity_score'] for log in self.history]
        losses = [log['loss'] for log in self.history]

        fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(12, 12))

        # One stacked panel per metric: (axis, values, line style, label).
        panels = (
            (ax1, accuracies, 'b-', 'Accuracy'),
            (ax2, complexities, 'r-', 'Complexity Score'),
            (ax3, losses, 'g-', 'Loss'),
        )
        for ax, values, line_style, name in panels:
            ax.plot(epochs, values, line_style, label=name)
            ax.set_ylabel(name)
            ax.grid(True)
            ax.legend()

        ax1.set_title('Network Evolution')
        ax3.set_xlabel('Epoch')

        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
98
+
99
def visualize_dataset(X, y, network=None, title=None, save_path=None):
    """Scatter-plot a two-class dataset, optionally over a decision surface.

    Points with label 1 are drawn in red and points with label -1 in blue,
    using the first two columns of ``X`` as coordinates. When ``network`` is
    given, its ``predict`` output on a 100x100 grid (padded by 0.5 around the
    data) is drawn as filled contours underneath. The figure is written to
    ``save_path`` when that is truthy, otherwise shown interactively.
    """
    plt.figure(figsize=(10, 8))

    if network is not None:
        pad = 0.5
        first, second = X[:, 0], X[:, 1]
        grid_x = np.linspace(first.min() - pad, first.max() + pad, 100)
        grid_y = np.linspace(second.min() - pad, second.max() + pad, 100)
        xx, yy = np.meshgrid(grid_x, grid_y)

        # Evaluate the network on every grid node, then fold back to 2-D.
        mesh_points = np.column_stack([xx.ravel(), yy.ravel()])
        surface = network.predict(jnp.array(mesh_points, dtype=jnp.float32))
        surface = surface.reshape(xx.shape)

        plt.contourf(xx, yy, surface, alpha=0.4, cmap='RdYlBu')

    for class_label, colour in ((1, 'red'), (-1, 'blue')):
        members = y == class_label
        plt.scatter(X[members, 0], X[members, 1], c=colour,
                    label=f'Class {class_label}')

    plt.grid(True)
    plt.legend()
    plt.title(title or 'Dataset')
    plt.xlabel('X1')
    plt.ylabel('X2')

    if not save_path:
        plt.show()
    else:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Saved plot to {save_path}")
    plt.close()
132
+
133
def _scheduled_learning_rate(epoch, base_lr=0.01, warmup_epochs=5, cycle_length=50):
    """Return the scheduled learning rate for ``epoch``.

    Linear warmup over ``warmup_epochs``, then cosine decay in cycles of
    ``cycle_length`` epochs. The first epoch of each cycle is set to
    ``base_lr * 0.9 ** cycle`` instead of the cosine value.
    """
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs

    cycle, t = divmod(epoch - warmup_epochs, cycle_length)
    if t == 0:
        return base_lr * (0.9 ** cycle)
    return base_lr * 0.5 * (1 + np.cos(t * np.pi / cycle_length))


def train_network(network, X, y, n_epochs=300, batch_size=32, patience=50):
    """Train a single network with mini-batches and early stopping.

    Each epoch shuffles the data, calls ``network._train_step`` once per
    batch, then measures sign-agreement accuracy on the full training set.
    The parameters of the most accurate epoch are restored before returning.

    Args:
        network: Object exposing ``params`` (a dict with a ``'weights'``
            mapping), ``_train_step(params, X_batch, y_batch)`` returning
            ``(new_params, loss)``, and ``predict(X)``.
        X: Training inputs, indexable along the first axis.
        y: Training targets; a sample counts as positive when ``y > 0``.
        n_epochs: Maximum number of epochs.
        batch_size: Samples per mini-batch. Samples beyond the last full
            batch are skipped in that epoch (the data is reshuffled every
            epoch). A dataset smaller than ``batch_size`` is used as one batch.
        patience: Epochs without accuracy improvement before stopping.

    Returns:
        The best training accuracy observed, as a ``float``.

    Raises:
        ValueError: If the dataset is empty or ``batch_size`` is not positive.
    """
    print("Starting network training...")
    print(f"Input shape: {X.shape}, Output shape: {y.shape}")
    print(f"Network params: {network.params['weights'].keys()}")

    n_samples = len(X)
    if n_samples == 0:
        raise ValueError("Cannot train on an empty dataset")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Floor division gave 0 batches for n_samples < batch_size, which skipped
    # training and then divided by zero when reporting the mean loss.
    n_batches = max(1, n_samples // batch_size)
    best_accuracy = 0.0
    patience_counter = 0
    best_params = None

    # Convert to JAX arrays
    print("Converting to JAX arrays...")
    X = jnp.array(X, dtype=jnp.float32)
    y = jnp.array(y, dtype=jnp.float32)

    # Learning rate schedule
    base_lr = 0.01
    warmup_epochs = 5

    print(f"\nTraining for {n_epochs} epochs with {n_batches} batches per epoch")
    print(f"Batch size: {batch_size}, Patience: {patience}")

    for epoch in range(n_epochs):
        try:
            # Shuffle inputs and targets with the same permutation.
            perm = np.random.permutation(n_samples)
            X = X[perm]
            y = y[perm]

            # NOTE(review): lr is only reported below. _train_step is called
            # without it, so whether this schedule influences the update
            # cannot be determined from this file.
            lr = _scheduled_learning_rate(epoch, base_lr, warmup_epochs)

            epoch_loss = 0.0

            # Train on mini-batches
            for i in range(n_batches):
                start_idx = i * batch_size
                end_idx = start_idx + batch_size
                X_batch = X[start_idx:end_idx]
                y_batch = y[start_idx:end_idx]

                try:
                    network.params, loss = network._train_step(
                        network.params,
                        X_batch,
                        y_batch
                    )
                    epoch_loss += loss
                except Exception as e:
                    print(f"Error in batch {i}: {str(e)}")
                    print(f"X_batch shape: {X_batch.shape}, y_batch shape: {y_batch.shape}")
                    raise

            # Flatten both sides before comparing: an (N, 1) prediction
            # against (N,) targets would otherwise broadcast to (N, N) and
            # report a meaningless accuracy.
            predictions = jnp.ravel(jnp.asarray(network.predict(X)))
            targets = jnp.ravel(y)
            train_accuracy = float(jnp.mean((predictions > 0) == (targets > 0)))

            # Early stopping check
            if train_accuracy > best_accuracy:
                best_accuracy = train_accuracy
                best_params = {k: v.copy() for k, v in network.params.items()}
                patience_counter = 0
            else:
                patience_counter += 1

            # Print progress every epoch
            print(f"Epoch {epoch}: Train Acc = {train_accuracy:.4f}, Loss = {epoch_loss/n_batches:.4f}, LR = {lr:.6f}")

            # Early stopping
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

        except Exception as e:
            print(f"Error in epoch {epoch}: {str(e)}")
            raise

    # Restore best parameters
    if best_params is not None:
        network.params = best_params
        print(f"\nRestored best parameters with accuracy: {best_accuracy:.4f}")

    return best_accuracy
229
+
230
def plot_decision_boundary(network, X, y, save_path):
    """Save a 2x2 figure: decision boundary plus three feature-space views.

    Panels: Cartesian view with the predicted surface, theta vs r, dr/dtheta
    vs theta, and r vs dr/dtheta.

    Args:
        network: Object whose ``predict`` accepts a five-column float32 array.
        X: Feature array; columns 0..4 are read, so it needs at least five
            columns. Presumably (x, y, r, theta, dr/dtheta) as produced by the
            dataset generator; confirm against ``generate_spiral_dataset``.
        y: Labels; points with label 1 are red, all others blue.
        save_path: Destination image path.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 15))

    # Cartesian grid covering the data with a small margin.
    x_min, x_max = X[:, 0].min() - 0.1, X[:, 0].max() + 0.1
    y_min, y_max = X[:, 1].min() - 0.1, X[:, 1].max() + 0.1
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
                         np.linspace(y_min, y_max, 100))

    # Derived features for every grid node.
    r = np.sqrt(xx**2 + yy**2)
    theta = np.unwrap(np.arctan2(yy, xx))
    # theta is exactly 0 on the positive x-axis; plain r / theta would yield
    # inf there and feed non-finite inputs to the network. Use 0 instead.
    dr_dtheta = np.divide(r, theta, out=np.zeros_like(r), where=theta != 0)

    # Scale factors, computed once and shared by the mesh and the scatters.
    x_scale = np.max(np.abs(X[:, 0]))
    y_scale = np.max(np.abs(X[:, 1]))
    dr_scale = np.max(np.abs(X[:, 4]))
    theta_scale = 6 * np.pi

    # Normalize features
    x_norm = xx.ravel() / x_scale
    y_norm = yy.ravel() / y_scale
    r_norm = r.ravel() / np.max(X[:, 2] * x_scale)
    theta_norm = theta.ravel() / theta_scale
    dr_norm = dr_dtheta.ravel() / dr_scale

    # Make predictions
    X_mesh = jnp.array(np.column_stack([
        x_norm, y_norm, r_norm, theta_norm, dr_norm
    ]), dtype=jnp.float32)
    Z = network.predict(X_mesh)
    Z = Z.reshape(xx.shape)

    # The same per-point colours are used in all four panels.
    colors = ['red' if label == 1 else 'blue' for label in y]

    # NOTE(review): the scatters rescale X by its own max-abs values, which is
    # a no-op only when X is already normalised to max-abs 1. Verify against
    # the dataset generator.
    x_points = X[:, 0] * x_scale
    y_points = X[:, 1] * y_scale
    radius = X[:, 2] * x_scale
    angle = X[:, 3] * theta_scale
    growth = X[:, 4] * dr_scale

    # Plot Cartesian view
    axes[0, 0].contourf(xx, yy, Z, alpha=0.4, cmap='RdYlBu')
    axes[0, 0].scatter(x_points, y_points, c=colors, alpha=0.6)
    axes[0, 0].set_title('Cartesian View')
    axes[0, 0].grid(True)

    # Plot Polar view (θ vs r)
    axes[0, 1].scatter(angle, radius, c=colors, alpha=0.6)
    axes[0, 1].set_title('Polar View (θ vs r)')
    axes[0, 1].grid(True)

    # Plot dr/dθ vs θ
    axes[1, 0].scatter(angle, growth, c=colors, alpha=0.6)
    axes[1, 0].set_title('Spiral Tightness (dr/dθ vs θ)')
    axes[1, 0].grid(True)

    # Plot r vs dr/dθ
    axes[1, 1].scatter(growth, radius, c=colors, alpha=0.6)
    axes[1, 1].set_title('Growth Rate (r vs dr/dθ)')
    axes[1, 1].grid(True)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
296
+
297
def main():
    """Train every network in a BackpropNEAT population on the spiral dataset.

    Generates the dataset, holds out 20% for validation, trains each network
    with ``train_network``, keeps the one with the highest validation
    accuracy, and saves its decision-boundary plot to
    ``spiral_decision_boundary.png``.
    """
    print("\nTraining on Spiral dataset...")

    # Generate spiral dataset
    X, y = generate_spiral_dataset(n_points=1000, noise=0.1)

    # Split data
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Initialize BackpropNEAT with smaller network
    n_features = X.shape[1]
    neat = BackpropNEAT(
        n_inputs=n_features,
        n_outputs=1,
        n_hidden=32,  # Reduced hidden layer size
        population_size=5,
        learning_rate=0.01,
        beta=0.9
    )

    # Training parameters
    n_epochs = 300
    batch_size = 32
    patience = 30  # Reduced patience

    # Train each network in the population
    best_network = None
    best_val_acc = 0.0
    population_size = len(neat.population)

    for position, network in enumerate(neat.population, start=1):
        print(f"\nTraining network {position}/{population_size}...")

        # Train network
        train_accuracy = train_network(
            network,
            X_train,
            y_train,
            n_epochs=n_epochs,
            batch_size=batch_size,
            patience=patience
        )

        # Flatten both sides before comparing: an (N, 1) prediction against
        # (N,) labels would otherwise broadcast to (N, N) and report a
        # meaningless accuracy.
        val_preds = np.ravel(np.asarray(network.predict(X_val)))
        val_targets = np.ravel(np.asarray(y_val))
        val_accuracy = float(np.mean((val_preds > 0) == (val_targets > 0)))

        print(f"Network {position} - Train Acc: {train_accuracy:.4f}, Val Acc: {val_accuracy:.4f}")

        # Always accept the first network so a population scoring 0.0
        # everywhere still yields a plot.
        if best_network is None or val_accuracy > best_val_acc:
            best_val_acc = val_accuracy
            best_network = network

    # Plot decision boundary for best network
    if best_network is not None:
        plot_path = "spiral_decision_boundary.png"
        plot_decision_boundary(best_network, X, y, plot_path)
        print(f"\nDecision boundary plot saved to {plot_path}")


if __name__ == "__main__":
    main()