"""Analysis utilities for neural networks. This module provides functions for analyzing neural network architectures, including complexity measures and structural properties. """ import numpy as np import networkx as nx import matplotlib.pyplot as plt from typing import Dict, Tuple, Union, Optional, List, Any from .network import Network from .genome import Genome from collections import defaultdict import os def analyze_network_complexity(network: Network) -> Dict[str, Any]: """Analyze the complexity of a neural network. Computes various complexity metrics including: 1. Number of nodes by type (input, hidden, output) 2. Number of connections 3. Network density 4. Activation functions used Args: network: Network instance to analyze Returns: Dictionary containing complexity metrics """ # Get network structure genome = network.genome # Count nodes by type n_input = genome.input_size n_hidden = len(genome.hidden_nodes) n_output = genome.output_size # Count connections n_connections = len(genome.connections) # Calculate connectivity density n_possible = (n_input + n_hidden + n_output) * (n_hidden + n_output) # No connections to input density = n_connections / n_possible if n_possible > 0 else 0 # Get activation functions (currently only ReLU) activation_functions = {'relu': n_hidden + n_output} # All nodes use ReLU return { 'n_input': n_input, 'n_hidden': n_hidden, 'n_output': n_output, 'n_connections': n_connections, 'density': density, 'activation_functions': activation_functions } def get_network_stats(network: Network) -> Dict[str, float]: """Get statistical measures of network properties. Computes various statistics about the network structure and parameters: - Number of nodes and connections - Average and std of weights and biases - Network density and depth Args: network: Network instance to analyze Returns: Dictionary containing network statistics """ stats = {} # Node counts stats['n_nodes'] = network.n_nodes stats['n_hidden'] = network.n_nodes - network.input_size - network.output_size # Connection stats weights = np.array(list(network.weights.values())) stats['n_connections'] = len(weights) stats['weight_mean'] = float(np.mean(weights)) stats['weight_std'] = float(np.std(weights)) # Bias stats biases = np.array(list(network.bias.values())) stats['n_biases'] = len(biases) stats['bias_mean'] = float(np.mean(biases)) stats['bias_std'] = float(np.std(biases)) # Connectivity n_possible = network.n_nodes * (network.n_nodes - 1) stats['density'] = len(weights) / n_possible if n_possible > 0 else 0 # Compute approximate network depth weight_matrix = network.weight_matrix depth = 0 visited = set(range(network.input_size)) frontier = visited.copy() while frontier and depth < network.n_nodes: next_frontier = set() for node in frontier: for next_node in range(network.n_nodes): if weight_matrix[node, next_node] != 0 and next_node not in visited: next_frontier.add(next_node) visited.add(next_node) frontier = next_frontier if frontier: depth += 1 stats['depth'] = depth return stats def visualize_network_architecture(network: Network, save_path: Optional[str] = None): """Visualize the network architecture using networkx. Creates a layered visualization of the neural network with: - Input nodes in red (leftmost layer) - Hidden nodes in blue (middle layer) - Output nodes in green (rightmost layer) - Connections shown as arrows with thickness proportional to weight Args: network: Network instance to visualize save_path: Optional path to save the visualization Returns: matplotlib figure object or None if visualization fails """ try: import networkx as nx import matplotlib.pyplot as plt genome = network.genome G = nx.DiGraph() # Calculate layout parameters n_inputs = len([node for node in genome.node_genes.values() if node.node_type == 'input']) n_outputs = len([node for node in genome.node_genes.values() if node.node_type == 'output']) hidden_nodes = [node.node_id for node in genome.node_genes.values() if node.node_type == 'hidden'] n_hidden = len(hidden_nodes) # Layout parameters node_spacing = 1.0 # Vertical spacing between nodes in same layer layer_spacing = 2.0 # Horizontal spacing between layers # Initialize position and color dictionaries pos = {} node_colors = {} # Add input nodes (leftmost layer) input_start_y = -(n_inputs - 1) * node_spacing / 2 # Center vertically input_nodes = [node.node_id for node in genome.node_genes.values() if node.node_type == 'input'] for i, node_idx in enumerate(input_nodes): pos[node_idx] = (0, input_start_y + i * node_spacing) node_colors[node_idx] = 'lightcoral' # Light red for input nodes # Add hidden nodes (middle layer) if hidden_nodes: hidden_start_y = -(n_hidden - 1) * node_spacing / 2 # Center vertically for i, node_idx in enumerate(hidden_nodes): pos[node_idx] = (layer_spacing, hidden_start_y + i * node_spacing) node_colors[node_idx] = 'lightblue' # Light blue for hidden nodes # Add output nodes (rightmost layer) output_start_y = -(n_outputs - 1) * node_spacing / 2 # Center vertically output_nodes = [node.node_id for node in genome.node_genes.values() if node.node_type == 'output'] for i, node_idx in enumerate(output_nodes): pos[node_idx] = (2 * layer_spacing, output_start_y + i * node_spacing) node_colors[node_idx] = 'lightgreen' # Light green for output nodes # Add bias node if present bias_node = [node.node_id for node in genome.node_genes.values() if node.node_type == 'bias'] if bias_node: pos[bias_node[0]] = (0, input_start_y - node_spacing) # Place below input nodes node_colors[bias_node[0]] = 'yellow' # Yellow for bias node # Add all nodes to graph and ensure they have colors and positions for node_id in genome.node_genes: G.add_node(node_id) if node_id not in node_colors: # Assign default color if not already assigned node_type = genome.node_genes[node_id].node_type if node_type == 'input': node_colors[node_id] = 'lightcoral' elif node_type == 'hidden': node_colors[node_id] = 'lightblue' elif node_type == 'output': node_colors[node_id] = 'lightgreen' elif node_type == 'bias': node_colors[node_id] = 'yellow' else: node_colors[node_id] = 'gray' # Default color for unknown types # Ensure node has a position if node_id not in pos: # Place unknown nodes in middle layer pos[node_id] = (layer_spacing, 0) # Add connections for conn in genome.connection_genes: if conn.enabled: # Scale connection width by weight width = abs(conn.weight) * 2.0 # Use red for negative weights, green for positive color = 'red' if conn.weight < 0 else 'green' alpha = min(abs(conn.weight), 1.0) # Transparency based on weight magnitude G.add_edge(conn.source, conn.target, weight=width, color=color, alpha=alpha) # Set up the plot fig = plt.figure(figsize=(12, 8)) # Draw nodes with colors nx.draw_networkx_nodes(G, pos, node_color=[node_colors[node] for node in G.nodes()], node_size=800, alpha=0.8) # Draw edges with width proportional to weight edge_weights = [G.get_edge_data(edge[0], edge[1])['weight'] for edge in G.edges()] if edge_weights: # Only draw edges if there are any max_weight = max(edge_weights) normalized_weights = [3 * w / max_weight for w in edge_weights] # Scale for visibility nx.draw_networkx_edges(G, pos, edge_color=[G.get_edge_data(edge[0], edge[1])['color'] for edge in G.edges()], width=normalized_weights, alpha=[G.get_edge_data(edge[0], edge[1])['alpha'] for edge in G.edges()], arrows=True, arrowsize=20) # Add node labels nx.draw_networkx_labels(G, pos, font_size=10) plt.title("Neural Network Architecture") plt.axis('off') # Hide axes if save_path: # Ensure the directory exists os.makedirs(os.path.dirname(save_path), exist_ok=True) plt.savefig(save_path, bbox_inches='tight', dpi=300) plt.close(fig) # Close the figure to free memory return fig except Exception as e: print(f"Error visualizing network: {str(e)}") return None def plot_activation_distribution(population: List[Genome], save_path: Optional[str] = None): """Plot the distribution of node types in the population. Args: population: List of genomes in the population save_path: Optional path to save the plot Returns: matplotlib figure object or None if plotting fails """ try: # Count nodes by type for each genome node_type_counts = defaultdict(int) for genome in population: node_type_counts['input'] += genome.input_size node_type_counts['hidden'] += len(genome.hidden_nodes) node_type_counts['output'] += genome.output_size if not node_type_counts: print("No nodes found in population") return None # Create bar plot fig = plt.figure(figsize=(10, 6)) # Get node types and counts node_types = list(node_type_counts.keys()) counts = list(node_type_counts.values()) # Create bars with different colors colors = {'input': 'lightcoral', 'hidden': 'lightblue', 'output': 'lightgreen'} plt.bar(node_types, counts, color=[colors[t] for t in node_types], alpha=0.7) # Customize plot plt.title('Distribution of Node Types in Population') plt.xlabel('Node Type') plt.ylabel('Total Count') # Add count labels on top of bars for i, count in enumerate(counts): plt.text(i, count, str(count), ha='center', va='bottom') plt.tight_layout() # Save or display if save_path: # Ensure the directory exists os.makedirs(os.path.dirname(save_path), exist_ok=True) plt.savefig(save_path, bbox_inches='tight', dpi=300) plt.close(fig) # Close the figure to free memory return fig except Exception as e: print(f"Error plotting activation distribution: {str(e)}") return None def analyze_evolution_trends(stats: Dict, save_dir: str) -> None: """Analyze and plot evolution trends from training history. Args: stats: Dictionary containing training statistics save_dir: Directory to save plots """ try: # Create plots directory if it doesn't exist os.makedirs(save_dir, exist_ok=True) # Check if we have any stats to plot if not stats or 'mean_fitness' not in stats or not stats['mean_fitness']: print("No evolution stats available yet") return # Extract metrics over generations generations = list(range(len(stats['mean_fitness']))) if not generations: # No data points yet print("No generations completed yet") return metrics = { 'Fitness': { 'mean': stats.get('mean_fitness', []), 'best': stats.get('best_fitness', []) }, 'Complexity': { 'mean': stats.get('mean_complexity', []), 'best': stats.get('best_complexity', []) } } # Plot each metric for metric_name, metric_data in metrics.items(): # Verify we have data for this metric if not metric_data['mean'] or not metric_data['best']: print(f"No data available for {metric_name}") continue # Verify data lengths match if len(generations) != len(metric_data['mean']) or len(generations) != len(metric_data['best']): print(f"Data length mismatch for {metric_name}") continue fig = plt.figure(figsize=(10, 6)) # Plot mean and best values plt.plot(generations, metric_data['mean'], label=f'Mean {metric_name}', alpha=0.7) plt.plot(generations, metric_data['best'], label=f'Best {metric_name}', alpha=0.7) plt.title(f'{metric_name} Over Generations') plt.xlabel('Generation') plt.ylabel(metric_name) plt.legend() plt.grid(True, alpha=0.3) # Save plot save_path = os.path.join(save_dir, f'{metric_name.lower()}_trends.png') plt.savefig(save_path, bbox_inches='tight', dpi=300) plt.close(fig) # Close the figure to free memory # Plot species counts if available if 'n_species' in stats and stats['n_species']: n_species = stats['n_species'] if len(generations) == len(n_species): # Verify data length matches fig = plt.figure(figsize=(10, 6)) plt.plot(generations, n_species, label='Number of Species', alpha=0.7) plt.title('Number of Species Over Generations') plt.xlabel('Generation') plt.ylabel('Number of Species') plt.legend() plt.grid(True, alpha=0.3) save_path = os.path.join(save_dir, 'species_trends.png') plt.savefig(save_path, bbox_inches='tight', dpi=300) plt.close(fig) # Close the figure to free memory else: print("Species count data length mismatch") except Exception as e: print(f"Error analyzing evolution trends: {str(e)}")