File size: 15,233 Bytes
c536443
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
"""Analysis utilities for neural networks.

This module provides functions for analyzing neural network architectures,
including complexity measures and structural properties.
"""

import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from typing import Dict, Tuple, Union, Optional, List, Any
from .network import Network
from .genome import Genome
from collections import defaultdict
import os

def analyze_network_complexity(network: Network) -> Dict[str, Any]:
    """Summarize the structural size of a network's genome.

    Reads ``network.genome`` and reports:

    * node counts split into input, hidden and output groups,
    * the number of entries in ``genome.connections``,
    * a connectivity density: the connection count divided by the number of
      (source, target) pairs where the source is any node and the target is a
      hidden or output node (input nodes never receive connections),
    * a fixed activation-function histogram that attributes every hidden and
      output node to ReLU.

    Args:
        network: Network whose ``genome`` is inspected.

    Returns:
        Dict with keys ``n_input``, ``n_hidden``, ``n_output``,
        ``n_connections``, ``density`` and ``activation_functions``.
        ``density`` is ``0`` when no connection is possible.
    """
    genome = network.genome
    inputs = genome.input_size
    hidden = len(genome.hidden_nodes)
    outputs = genome.output_size
    edges = len(genome.connections)

    # Any node can be a source; only hidden and output nodes can be targets.
    receivers = hidden + outputs
    capacity = (inputs + receivers) * receivers
    if capacity > 0:
        density = edges / capacity
    else:
        density = 0

    return {
        'n_input': inputs,
        'n_hidden': hidden,
        'n_output': outputs,
        'n_connections': edges,
        'density': density,
        # Not read from the genome: every non-input node is counted as ReLU.
        'activation_functions': {'relu': receivers},
    }

def _mean_std(values: np.ndarray) -> Tuple[float, float]:
    """Return ``(mean, std)`` of *values* as Python floats.

    ``np.mean``/``np.std`` on an empty array emit a RuntimeWarning and return
    NaN, so an empty input yields ``(0.0, 0.0)`` instead.
    """
    if values.size == 0:
        return 0.0, 0.0
    return float(np.mean(values)), float(np.std(values))


def _forward_depth(weight_matrix, n_nodes: int, input_size: int) -> int:
    """Return the number of breadth-first layers reachable from the inputs.

    Nodes ``0 .. input_size - 1`` form the starting layer. ``node -> other`` is
    treated as an edge when ``weight_matrix[node, other] != 0`` (assumes rows
    are sources and columns are targets — TODO confirm against ``Network``).
    Each BFS step that reaches at least one unvisited node adds one to the
    depth. This is the shortest-path distance to the farthest reachable node,
    not the longest path. The result is capped at ``n_nodes``.
    """
    visited = set(range(input_size))
    frontier = set(visited)
    depth = 0
    while frontier and depth < n_nodes:
        next_frontier = set()
        for node in frontier:
            for other in range(n_nodes):
                if weight_matrix[node, other] != 0 and other not in visited:
                    next_frontier.add(other)
                    visited.add(other)
        frontier = next_frontier
        if frontier:
            depth += 1
    return depth


def get_network_stats(network: Network) -> Dict[str, float]:
    """Get statistical measures of network structure and parameters.

    Computed values:

    - ``n_nodes``: total node count; ``n_hidden``: nodes that are neither
      inputs nor outputs.
    - ``n_connections``, ``weight_mean``, ``weight_std``: statistics over
      ``network.weights.values()``.
    - ``n_biases``, ``bias_mean``, ``bias_std``: statistics over
      ``network.bias.values()``.
    - ``density``: connections divided by the number of ordered pairs of
      distinct nodes, ``n_nodes * (n_nodes - 1)``.
    - ``depth``: number of BFS layers reachable from the input nodes through
      non-zero entries of ``network.weight_matrix``.

    Means and standard deviations are ``0.0`` (not NaN) when there are no
    weights or biases.

    Args:
        network: Network instance to analyze.

    Returns:
        Dictionary containing the statistics above, in that key order.
    """
    stats = {}

    # Node counts
    stats['n_nodes'] = network.n_nodes
    stats['n_hidden'] = network.n_nodes - network.input_size - network.output_size

    # Connection weight statistics
    weights = np.array(list(network.weights.values()))
    stats['n_connections'] = len(weights)
    stats['weight_mean'], stats['weight_std'] = _mean_std(weights)

    # Bias statistics
    biases = np.array(list(network.bias.values()))
    stats['n_biases'] = len(biases)
    stats['bias_mean'], stats['bias_std'] = _mean_std(biases)

    # Density over all ordered pairs of distinct nodes.
    n_possible = network.n_nodes * (network.n_nodes - 1)
    stats['density'] = len(weights) / n_possible if n_possible > 0 else 0.0

    stats['depth'] = _forward_depth(network.weight_matrix, network.n_nodes, network.input_size)

    return stats

def visualize_network_architecture(network: Network, save_path: Optional[str] = None):
    """Visualize the network architecture using networkx.

    Creates a three-column layout of ``network.genome``:

    - input nodes in light red (left column, x = 0)
    - hidden nodes in light blue (middle column)
    - output nodes in light green (right column)
    - the first bias node, if any, in yellow just below the input column
    - enabled connections as arrows: green for positive and red for negative
      weights. Width is proportional to ``|weight|`` (scaled so the largest
      is 3); opacity is ``min(|weight|, 1)``.

    Nodes of any other type, and any extra bias nodes, are drawn in the middle
    column at y = 0. If every enabled connection has weight 0, only the nodes
    are drawn.

    Args:
        network: Network instance to visualize.
        save_path: Optional file path. When given, the figure is saved there
            (creating the parent directory if the path has one) and then
            closed.

    Returns:
        The matplotlib figure (already closed if ``save_path`` was given), or
        ``None`` if visualization fails. On failure the error is printed and
        any partially built figure is closed.
    """
    fig = None
    try:
        genome = network.genome
        G = nx.DiGraph()

        node_spacing = 1.0   # vertical gap between nodes in the same column
        layer_spacing = 2.0  # horizontal gap between columns
        type_colors = {
            'input': 'lightcoral',
            'hidden': 'lightblue',
            'output': 'lightgreen',
            'bias': 'yellow',
        }
        layer_x = {'input': 0, 'hidden': layer_spacing, 'output': 2 * layer_spacing}

        # Group node ids by type in a single pass, preserving genome order.
        nodes_by_type = defaultdict(list)
        for gene in genome.node_genes.values():
            nodes_by_type[gene.node_type].append(gene.node_id)

        pos = {}
        node_colors = {}
        for node_type, x in layer_x.items():
            ids = nodes_by_type.get(node_type, [])
            start_y = -(len(ids) - 1) * node_spacing / 2  # centre the column vertically
            for i, node_id in enumerate(ids):
                pos[node_id] = (x, start_y + i * node_spacing)
                node_colors[node_id] = type_colors[node_type]

        # Only the first bias node gets a dedicated slot, one step below the inputs.
        bias_ids = nodes_by_type.get('bias', [])
        if bias_ids:
            input_start_y = -(len(nodes_by_type.get('input', [])) - 1) * node_spacing / 2
            pos[bias_ids[0]] = (0, input_start_y - node_spacing)
            node_colors[bias_ids[0]] = type_colors['bias']

        # Every genome node gets a color and position, falling back to gray
        # and the middle column for anything not placed above.
        for node_id, gene in genome.node_genes.items():
            G.add_node(node_id)
            node_colors.setdefault(node_id, type_colors.get(gene.node_type, 'gray'))
            pos.setdefault(node_id, (layer_spacing, 0))

        for conn in genome.connection_genes:
            if not conn.enabled:
                continue
            magnitude = abs(conn.weight)
            G.add_edge(
                conn.source,
                conn.target,
                weight=magnitude * 2.0,
                color='red' if conn.weight < 0 else 'green',
                alpha=min(magnitude, 1.0),
            )

        fig = plt.figure(figsize=(12, 8))

        nx.draw_networkx_nodes(G, pos, node_color=[node_colors[node] for node in G.nodes()],
                               node_size=800, alpha=0.8)

        edges = list(G.edges(data=True))
        widths = [data['weight'] for _, _, data in edges]
        max_width = max(widths, default=0.0)
        # Skip edges entirely when all weights are zero; normalizing would divide by zero.
        if max_width > 0:
            nx.draw_networkx_edges(
                G,
                pos,
                edgelist=[(u, v) for u, v, _ in edges],
                edge_color=[data['color'] for _, _, data in edges],
                width=[3 * w / max_width for w in widths],
                alpha=[data['alpha'] for _, _, data in edges],
                arrows=True,
                arrowsize=20,
            )

        nx.draw_networkx_labels(G, pos, font_size=10)

        plt.title("Neural Network Architecture")
        plt.axis('off')

        if save_path:
            directory = os.path.dirname(save_path)
            # os.makedirs('') raises FileNotFoundError, so skip it for bare filenames.
            if directory:
                os.makedirs(directory, exist_ok=True)
            plt.savefig(save_path, bbox_inches='tight', dpi=300)
            plt.close(fig)  # free memory; the closed figure is still returned

        return fig

    except Exception as e:
        print(f"Error visualizing network: {str(e)}")
        if fig is not None:
            plt.close(fig)  # don't leak a half-built figure
        return None

def plot_activation_distribution(population: List[Genome], save_path: Optional[str] = None):
    """Plot the total number of input, hidden and output nodes in a population.

    Despite its name, this function plots node-type counts summed over all
    genomes, not activation functions.

    Args:
        population: Genomes to summarize. Each must expose ``input_size``,
            ``hidden_nodes`` and ``output_size``.
        save_path: Optional file path. When given, the figure is saved there
            (creating the parent directory if the path has one) and then
            closed.

    Returns:
        The matplotlib figure (already closed if ``save_path`` was given), or
        ``None`` if the population is empty or plotting fails.
    """
    fig = None
    try:
        # Sum node counts per type over every genome.
        node_type_counts = defaultdict(int)
        for genome in population:
            node_type_counts['input'] += genome.input_size
            node_type_counts['hidden'] += len(genome.hidden_nodes)
            node_type_counts['output'] += genome.output_size

        if not node_type_counts:  # only happens for an empty population
            print("No nodes found in population")
            return None

        fig = plt.figure(figsize=(10, 6))

        node_types = list(node_type_counts.keys())
        counts = list(node_type_counts.values())

        colors = {'input': 'lightcoral', 'hidden': 'lightblue', 'output': 'lightgreen'}
        plt.bar(node_types, counts, color=[colors[t] for t in node_types], alpha=0.7)

        plt.title('Distribution of Node Types in Population')
        plt.xlabel('Node Type')
        plt.ylabel('Total Count')

        # Categorical bars sit at x = 0, 1, 2, ...; label each with its count.
        for i, count in enumerate(counts):
            plt.text(i, count, str(count), ha='center', va='bottom')

        plt.tight_layout()

        if save_path:
            directory = os.path.dirname(save_path)
            # os.makedirs('') raises FileNotFoundError, so skip it for bare filenames.
            if directory:
                os.makedirs(directory, exist_ok=True)
            plt.savefig(save_path, bbox_inches='tight', dpi=300)
            plt.close(fig)  # free memory; the closed figure is still returned

        return fig

    except Exception as e:
        print(f"Error plotting activation distribution: {str(e)}")
        if fig is not None:
            plt.close(fig)  # don't leak a half-built figure
        return None

def _save_trend_plot(generations: List[int], series: Dict[str, Any], title: str,
                     ylabel: str, save_path: str) -> None:
    """Draw labelled line series against generation index and save to *save_path*.

    The figure is closed in ``finally`` so that a failing plot or save does not
    leave an open matplotlib figure behind when this is called repeatedly.

    Args:
        generations: x values (generation indices).
        series: Mapping of legend label -> y values, each as long as ``generations``.
        title: Plot title.
        ylabel: Y-axis label.
        save_path: Output image path.
    """
    fig = plt.figure(figsize=(10, 6))
    try:
        for label, values in series.items():
            plt.plot(generations, values, label=label, alpha=0.7)
        plt.title(title)
        plt.xlabel('Generation')
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
    finally:
        plt.close(fig)


def analyze_evolution_trends(stats: Dict, save_dir: str) -> None:
    """Plot evolution trends from training history into *save_dir*.

    Writes up to three PNG files:

    - ``fitness_trends.png`` from ``mean_fitness`` / ``best_fitness``
    - ``complexity_trends.png`` from ``mean_complexity`` / ``best_complexity``
    - ``species_trends.png`` from ``n_species``

    The generation count is taken from ``len(stats['mean_fitness'])``. A series
    that is missing, empty, or of a different length is skipped with a printed
    message. Any other error is printed rather than raised.

    Args:
        stats: Dictionary containing training statistics (sequences keyed as above).
        save_dir: Directory to save plots; created if it does not exist.
    """
    try:
        os.makedirs(save_dir, exist_ok=True)

        if not stats or not stats.get('mean_fitness'):
            print("No evolution stats available yet")
            return

        # mean_fitness is non-empty here, so at least one generation exists.
        generations = list(range(len(stats['mean_fitness'])))

        metrics = {
            'Fitness': (stats.get('mean_fitness', []), stats.get('best_fitness', [])),
            'Complexity': (stats.get('mean_complexity', []), stats.get('best_complexity', [])),
        }

        for metric_name, (mean_values, best_values) in metrics.items():
            if not mean_values or not best_values:
                print(f"No data available for {metric_name}")
                continue
            if len(generations) != len(mean_values) or len(generations) != len(best_values):
                print(f"Data length mismatch for {metric_name}")
                continue
            _save_trend_plot(
                generations,
                {f'Mean {metric_name}': mean_values, f'Best {metric_name}': best_values},
                f'{metric_name} Over Generations',
                metric_name,
                os.path.join(save_dir, f'{metric_name.lower()}_trends.png'),
            )

        n_species = stats.get('n_species')
        if n_species:
            if len(generations) == len(n_species):
                _save_trend_plot(
                    generations,
                    {'Number of Species': n_species},
                    'Number of Species Over Generations',
                    'Number of Species',
                    os.path.join(save_dir, 'species_trends.png'),
                )
            else:
                print("Species count data length mismatch")

    except Exception as e:
        print(f"Error analyzing evolution trends: {str(e)}")