eyad-silx commited on
Commit
c536443
·
verified ·
1 Parent(s): ecccd48

Upload neat\analysis.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. neat//analysis.py +383 -0
neat//analysis.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Analysis utilities for neural networks.
2
+
3
+ This module provides functions for analyzing neural network architectures,
4
+ including complexity measures and structural properties.
5
+ """
6
+
7
+ import numpy as np
8
+ import networkx as nx
9
+ import matplotlib.pyplot as plt
10
+ from typing import Dict, Tuple, Union, Optional, List, Any
11
+ from .network import Network
12
+ from .genome import Genome
13
+ from collections import defaultdict
14
+ import os
15
+
16
+ def analyze_network_complexity(network: Network) -> Dict[str, Any]:
17
+ """Analyze the complexity of a neural network.
18
+
19
+ Computes various complexity metrics including:
20
+ 1. Number of nodes by type (input, hidden, output)
21
+ 2. Number of connections
22
+ 3. Network density
23
+ 4. Activation functions used
24
+
25
+ Args:
26
+ network: Network instance to analyze
27
+
28
+ Returns:
29
+ Dictionary containing complexity metrics
30
+ """
31
+ # Get network structure
32
+ genome = network.genome
33
+
34
+ # Count nodes by type
35
+ n_input = genome.input_size
36
+ n_hidden = len(genome.hidden_nodes)
37
+ n_output = genome.output_size
38
+
39
+ # Count connections
40
+ n_connections = len(genome.connections)
41
+
42
+ # Calculate connectivity density
43
+ n_possible = (n_input + n_hidden + n_output) * (n_hidden + n_output) # No connections to input
44
+ density = n_connections / n_possible if n_possible > 0 else 0
45
+
46
+ # Get activation functions (currently only ReLU)
47
+ activation_functions = {'relu': n_hidden + n_output} # All nodes use ReLU
48
+
49
+ return {
50
+ 'n_input': n_input,
51
+ 'n_hidden': n_hidden,
52
+ 'n_output': n_output,
53
+ 'n_connections': n_connections,
54
+ 'density': density,
55
+ 'activation_functions': activation_functions
56
+ }
57
+
58
+ def get_network_stats(network: Network) -> Dict[str, float]:
59
+ """Get statistical measures of network properties.
60
+
61
+ Computes various statistics about the network structure and parameters:
62
+ - Number of nodes and connections
63
+ - Average and std of weights and biases
64
+ - Network density and depth
65
+
66
+ Args:
67
+ network: Network instance to analyze
68
+
69
+ Returns:
70
+ Dictionary containing network statistics
71
+ """
72
+ stats = {}
73
+
74
+ # Node counts
75
+ stats['n_nodes'] = network.n_nodes
76
+ stats['n_hidden'] = network.n_nodes - network.input_size - network.output_size
77
+
78
+ # Connection stats
79
+ weights = np.array(list(network.weights.values()))
80
+ stats['n_connections'] = len(weights)
81
+ stats['weight_mean'] = float(np.mean(weights))
82
+ stats['weight_std'] = float(np.std(weights))
83
+
84
+ # Bias stats
85
+ biases = np.array(list(network.bias.values()))
86
+ stats['n_biases'] = len(biases)
87
+ stats['bias_mean'] = float(np.mean(biases))
88
+ stats['bias_std'] = float(np.std(biases))
89
+
90
+ # Connectivity
91
+ n_possible = network.n_nodes * (network.n_nodes - 1)
92
+ stats['density'] = len(weights) / n_possible if n_possible > 0 else 0
93
+
94
+ # Compute approximate network depth
95
+ weight_matrix = network.weight_matrix
96
+ depth = 0
97
+ visited = set(range(network.input_size))
98
+ frontier = visited.copy()
99
+
100
+ while frontier and depth < network.n_nodes:
101
+ next_frontier = set()
102
+ for node in frontier:
103
+ for next_node in range(network.n_nodes):
104
+ if weight_matrix[node, next_node] != 0 and next_node not in visited:
105
+ next_frontier.add(next_node)
106
+ visited.add(next_node)
107
+ frontier = next_frontier
108
+ if frontier:
109
+ depth += 1
110
+
111
+ stats['depth'] = depth
112
+
113
+ return stats
114
+
115
+ def visualize_network_architecture(network: Network, save_path: Optional[str] = None):
116
+ """Visualize the network architecture using networkx.
117
+
118
+ Creates a layered visualization of the neural network with:
119
+ - Input nodes in red (leftmost layer)
120
+ - Hidden nodes in blue (middle layer)
121
+ - Output nodes in green (rightmost layer)
122
+ - Connections shown as arrows with thickness proportional to weight
123
+
124
+ Args:
125
+ network: Network instance to visualize
126
+ save_path: Optional path to save the visualization
127
+
128
+ Returns:
129
+ matplotlib figure object or None if visualization fails
130
+ """
131
+ try:
132
+ import networkx as nx
133
+ import matplotlib.pyplot as plt
134
+
135
+ genome = network.genome
136
+ G = nx.DiGraph()
137
+
138
+ # Calculate layout parameters
139
+ n_inputs = len([node for node in genome.node_genes.values() if node.node_type == 'input'])
140
+ n_outputs = len([node for node in genome.node_genes.values() if node.node_type == 'output'])
141
+ hidden_nodes = [node.node_id for node in genome.node_genes.values() if node.node_type == 'hidden']
142
+ n_hidden = len(hidden_nodes)
143
+
144
+ # Layout parameters
145
+ node_spacing = 1.0 # Vertical spacing between nodes in same layer
146
+ layer_spacing = 2.0 # Horizontal spacing between layers
147
+
148
+ # Initialize position and color dictionaries
149
+ pos = {}
150
+ node_colors = {}
151
+
152
+ # Add input nodes (leftmost layer)
153
+ input_start_y = -(n_inputs - 1) * node_spacing / 2 # Center vertically
154
+ input_nodes = [node.node_id for node in genome.node_genes.values() if node.node_type == 'input']
155
+ for i, node_idx in enumerate(input_nodes):
156
+ pos[node_idx] = (0, input_start_y + i * node_spacing)
157
+ node_colors[node_idx] = 'lightcoral' # Light red for input nodes
158
+
159
+ # Add hidden nodes (middle layer)
160
+ if hidden_nodes:
161
+ hidden_start_y = -(n_hidden - 1) * node_spacing / 2 # Center vertically
162
+ for i, node_idx in enumerate(hidden_nodes):
163
+ pos[node_idx] = (layer_spacing, hidden_start_y + i * node_spacing)
164
+ node_colors[node_idx] = 'lightblue' # Light blue for hidden nodes
165
+
166
+ # Add output nodes (rightmost layer)
167
+ output_start_y = -(n_outputs - 1) * node_spacing / 2 # Center vertically
168
+ output_nodes = [node.node_id for node in genome.node_genes.values() if node.node_type == 'output']
169
+ for i, node_idx in enumerate(output_nodes):
170
+ pos[node_idx] = (2 * layer_spacing, output_start_y + i * node_spacing)
171
+ node_colors[node_idx] = 'lightgreen' # Light green for output nodes
172
+
173
+ # Add bias node if present
174
+ bias_node = [node.node_id for node in genome.node_genes.values() if node.node_type == 'bias']
175
+ if bias_node:
176
+ pos[bias_node[0]] = (0, input_start_y - node_spacing) # Place below input nodes
177
+ node_colors[bias_node[0]] = 'yellow' # Yellow for bias node
178
+
179
+ # Add all nodes to graph and ensure they have colors and positions
180
+ for node_id in genome.node_genes:
181
+ G.add_node(node_id)
182
+ if node_id not in node_colors: # Assign default color if not already assigned
183
+ node_type = genome.node_genes[node_id].node_type
184
+ if node_type == 'input':
185
+ node_colors[node_id] = 'lightcoral'
186
+ elif node_type == 'hidden':
187
+ node_colors[node_id] = 'lightblue'
188
+ elif node_type == 'output':
189
+ node_colors[node_id] = 'lightgreen'
190
+ elif node_type == 'bias':
191
+ node_colors[node_id] = 'yellow'
192
+ else:
193
+ node_colors[node_id] = 'gray' # Default color for unknown types
194
+
195
+ # Ensure node has a position
196
+ if node_id not in pos:
197
+ # Place unknown nodes in middle layer
198
+ pos[node_id] = (layer_spacing, 0)
199
+
200
+ # Add connections
201
+ for conn in genome.connection_genes:
202
+ if conn.enabled:
203
+ # Scale connection width by weight
204
+ width = abs(conn.weight) * 2.0
205
+ # Use red for negative weights, green for positive
206
+ color = 'red' if conn.weight < 0 else 'green'
207
+ alpha = min(abs(conn.weight), 1.0) # Transparency based on weight magnitude
208
+ G.add_edge(conn.source, conn.target, weight=width, color=color, alpha=alpha)
209
+
210
+ # Set up the plot
211
+ fig = plt.figure(figsize=(12, 8))
212
+
213
+ # Draw nodes with colors
214
+ nx.draw_networkx_nodes(G, pos, node_color=[node_colors[node] for node in G.nodes()],
215
+ node_size=800, alpha=0.8)
216
+
217
+ # Draw edges with width proportional to weight
218
+ edge_weights = [G.get_edge_data(edge[0], edge[1])['weight'] for edge in G.edges()]
219
+ if edge_weights: # Only draw edges if there are any
220
+ max_weight = max(edge_weights)
221
+ normalized_weights = [3 * w / max_weight for w in edge_weights] # Scale for visibility
222
+ nx.draw_networkx_edges(G, pos, edge_color=[G.get_edge_data(edge[0], edge[1])['color'] for edge in G.edges()],
223
+ width=normalized_weights,
224
+ alpha=[G.get_edge_data(edge[0], edge[1])['alpha'] for edge in G.edges()],
225
+ arrows=True, arrowsize=20)
226
+
227
+ # Add node labels
228
+ nx.draw_networkx_labels(G, pos, font_size=10)
229
+
230
+ plt.title("Neural Network Architecture")
231
+ plt.axis('off') # Hide axes
232
+
233
+ if save_path:
234
+ # Ensure the directory exists
235
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
236
+ plt.savefig(save_path, bbox_inches='tight', dpi=300)
237
+ plt.close(fig) # Close the figure to free memory
238
+
239
+ return fig
240
+
241
+ except Exception as e:
242
+ print(f"Error visualizing network: {str(e)}")
243
+ return None
244
+
245
+ def plot_activation_distribution(population: List[Genome], save_path: Optional[str] = None):
246
+ """Plot the distribution of node types in the population.
247
+
248
+ Args:
249
+ population: List of genomes in the population
250
+ save_path: Optional path to save the plot
251
+
252
+ Returns:
253
+ matplotlib figure object or None if plotting fails
254
+ """
255
+ try:
256
+ # Count nodes by type for each genome
257
+ node_type_counts = defaultdict(int)
258
+ for genome in population:
259
+ node_type_counts['input'] += genome.input_size
260
+ node_type_counts['hidden'] += len(genome.hidden_nodes)
261
+ node_type_counts['output'] += genome.output_size
262
+
263
+ if not node_type_counts:
264
+ print("No nodes found in population")
265
+ return None
266
+
267
+ # Create bar plot
268
+ fig = plt.figure(figsize=(10, 6))
269
+
270
+ # Get node types and counts
271
+ node_types = list(node_type_counts.keys())
272
+ counts = list(node_type_counts.values())
273
+
274
+ # Create bars with different colors
275
+ colors = {'input': 'lightcoral', 'hidden': 'lightblue', 'output': 'lightgreen'}
276
+ plt.bar(node_types, counts, color=[colors[t] for t in node_types], alpha=0.7)
277
+
278
+ # Customize plot
279
+ plt.title('Distribution of Node Types in Population')
280
+ plt.xlabel('Node Type')
281
+ plt.ylabel('Total Count')
282
+
283
+ # Add count labels on top of bars
284
+ for i, count in enumerate(counts):
285
+ plt.text(i, count, str(count), ha='center', va='bottom')
286
+
287
+ plt.tight_layout()
288
+
289
+ # Save or display
290
+ if save_path:
291
+ # Ensure the directory exists
292
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
293
+ plt.savefig(save_path, bbox_inches='tight', dpi=300)
294
+ plt.close(fig) # Close the figure to free memory
295
+
296
+ return fig
297
+
298
+ except Exception as e:
299
+ print(f"Error plotting activation distribution: {str(e)}")
300
+ return None
301
+
302
+ def analyze_evolution_trends(stats: Dict, save_dir: str) -> None:
303
+ """Analyze and plot evolution trends from training history.
304
+
305
+ Args:
306
+ stats: Dictionary containing training statistics
307
+ save_dir: Directory to save plots
308
+ """
309
+ try:
310
+ # Create plots directory if it doesn't exist
311
+ os.makedirs(save_dir, exist_ok=True)
312
+
313
+ # Check if we have any stats to plot
314
+ if not stats or 'mean_fitness' not in stats or not stats['mean_fitness']:
315
+ print("No evolution stats available yet")
316
+ return
317
+
318
+ # Extract metrics over generations
319
+ generations = list(range(len(stats['mean_fitness'])))
320
+ if not generations: # No data points yet
321
+ print("No generations completed yet")
322
+ return
323
+
324
+ metrics = {
325
+ 'Fitness': {
326
+ 'mean': stats.get('mean_fitness', []),
327
+ 'best': stats.get('best_fitness', [])
328
+ },
329
+ 'Complexity': {
330
+ 'mean': stats.get('mean_complexity', []),
331
+ 'best': stats.get('best_complexity', [])
332
+ }
333
+ }
334
+
335
+ # Plot each metric
336
+ for metric_name, metric_data in metrics.items():
337
+ # Verify we have data for this metric
338
+ if not metric_data['mean'] or not metric_data['best']:
339
+ print(f"No data available for {metric_name}")
340
+ continue
341
+
342
+ # Verify data lengths match
343
+ if len(generations) != len(metric_data['mean']) or len(generations) != len(metric_data['best']):
344
+ print(f"Data length mismatch for {metric_name}")
345
+ continue
346
+
347
+ fig = plt.figure(figsize=(10, 6))
348
+
349
+ # Plot mean and best values
350
+ plt.plot(generations, metric_data['mean'], label=f'Mean {metric_name}', alpha=0.7)
351
+ plt.plot(generations, metric_data['best'], label=f'Best {metric_name}', alpha=0.7)
352
+
353
+ plt.title(f'{metric_name} Over Generations')
354
+ plt.xlabel('Generation')
355
+ plt.ylabel(metric_name)
356
+ plt.legend()
357
+ plt.grid(True, alpha=0.3)
358
+
359
+ # Save plot
360
+ save_path = os.path.join(save_dir, f'{metric_name.lower()}_trends.png')
361
+ plt.savefig(save_path, bbox_inches='tight', dpi=300)
362
+ plt.close(fig) # Close the figure to free memory
363
+
364
+ # Plot species counts if available
365
+ if 'n_species' in stats and stats['n_species']:
366
+ n_species = stats['n_species']
367
+ if len(generations) == len(n_species): # Verify data length matches
368
+ fig = plt.figure(figsize=(10, 6))
369
+ plt.plot(generations, n_species, label='Number of Species', alpha=0.7)
370
+ plt.title('Number of Species Over Generations')
371
+ plt.xlabel('Generation')
372
+ plt.ylabel('Number of Species')
373
+ plt.legend()
374
+ plt.grid(True, alpha=0.3)
375
+
376
+ save_path = os.path.join(save_dir, 'species_trends.png')
377
+ plt.savefig(save_path, bbox_inches='tight', dpi=300)
378
+ plt.close(fig) # Close the figure to free memory
379
+ else:
380
+ print("Species count data length mismatch")
381
+
382
+ except Exception as e:
383
+ print(f"Error analyzing evolution trends: {str(e)}")