"""Analysis utilities for neural networks. |
This module provides functions for analyzing neural network architectures, |
including complexity measures and structural properties. |
""" |
import numpy as np |
import networkx as nx |
import matplotlib.pyplot as plt |
from typing import Dict, Tuple, Union, Optional, List, Any |
from .network import Network |
from .genome import Genome |
from collections import defaultdict |
import os |
def analyze_network_complexity(network: Network) -> Dict[str, Any]:
    """Summarize the structural complexity of a network.

    Everything is read from ``network.genome``:

    * node counts per role (input, hidden, output),
    * the number of entries in the genome's connection collection,
    * density: connections divided by the number of
      (any node) -> (hidden or output node) pairs,
    * an activation-function tally, which attributes every hidden and
      output node to ``'relu'``.

    Args:
        network: Network instance to analyze.

    Returns:
        Dictionary with the keys ``n_input``, ``n_hidden``, ``n_output``,
        ``n_connections``, ``density`` and ``activation_functions``.
    """
    genome = network.genome
    input_count = genome.input_size
    hidden_count = len(genome.hidden_nodes)
    output_count = genome.output_size
    connection_count = len(genome.connections)

    # Any node may be a source, but only hidden/output nodes may be targets.
    target_count = hidden_count + output_count
    possible_count = (input_count + target_count) * target_count

    if possible_count > 0:
        density = connection_count / possible_count
    else:
        density = 0

    summary: Dict[str, Any] = {}
    summary['n_input'] = input_count
    summary['n_hidden'] = hidden_count
    summary['n_output'] = output_count
    summary['n_connections'] = connection_count
    summary['density'] = density
    summary['activation_functions'] = {'relu': target_count}
    return summary
def get_network_stats(network: Network) -> Dict[str, float]:
    """Get statistical measures of network properties.

    Reported keys:
        - ``n_nodes``, ``n_hidden``: total node count and the count left
          after subtracting input and output nodes.
        - ``n_connections``, ``weight_mean``, ``weight_std``: size, mean and
          standard deviation of the values in ``network.weights``.
        - ``n_biases``, ``bias_mean``, ``bias_std``: the same for
          ``network.bias``.
        - ``density``: connections divided by ``n_nodes * (n_nodes - 1)``.
        - ``depth``: number of breadth-first layers reachable from the input
          nodes (indices ``0 .. input_size - 1``) by following non-zero
          entries ``weight_matrix[source, target]``.

    A network without connections (or without biases) reports a mean and
    standard deviation of 0.0 instead of NaN.

    Args:
        network: Network instance to analyze

    Returns:
        Dictionary containing network statistics
    """

    def mean_and_std(values):
        # np.mean/np.std on an empty array yield NaN plus a RuntimeWarning.
        if values.size == 0:
            return 0.0, 0.0
        return float(np.mean(values)), float(np.std(values))

    n_nodes = network.n_nodes
    stats = {}
    stats['n_nodes'] = n_nodes
    stats['n_hidden'] = n_nodes - network.input_size - network.output_size

    weights = np.array(list(network.weights.values()))
    stats['n_connections'] = len(weights)
    stats['weight_mean'], stats['weight_std'] = mean_and_std(weights)

    biases = np.array(list(network.bias.values()))
    stats['n_biases'] = len(biases)
    stats['bias_mean'], stats['bias_std'] = mean_and_std(biases)

    # Ordered node pairs without self-connections.
    n_possible = n_nodes * (n_nodes - 1)
    stats['density'] = len(weights) / n_possible if n_possible > 0 else 0.0

    # Breadth-first expansion from the input nodes; each non-empty new
    # frontier adds one layer of depth.
    weight_matrix = network.weight_matrix
    depth = 0
    visited = set(range(network.input_size))
    frontier = set(visited)
    while frontier and depth < n_nodes:
        next_frontier = set()
        for node in frontier:
            for next_node in range(n_nodes):
                if next_node in visited:
                    continue
                if weight_matrix[node, next_node] != 0:
                    next_frontier.add(next_node)
                    visited.add(next_node)
        frontier = next_frontier
        if frontier:
            depth += 1
    stats['depth'] = depth

    return stats
def visualize_network_architecture(network: Network, save_path: Optional[str] = None):
    """Visualize the network architecture using networkx.

    Layout and styling:
        - Input nodes form the left column (light coral).
        - Hidden nodes form the middle column (light blue).
        - Output nodes form the right column (light green).
        - The first bias node, if any, sits one slot below the input column
          (yellow).
        - Any node not placed by the rules above is drawn at the centre of
          the middle column, coloured by its type (gray for unknown types).
        - Only enabled connections are drawn. Negative weights are red, all
          others green; opacity is ``min(|weight|, 1)`` and line width is
          proportional to ``|weight|``, scaled so the heaviest edge has
          width 3.

    Args:
        network: Network whose ``genome`` provides ``node_genes`` (mapping of
            node id to gene) and ``connection_genes``.
        save_path: Optional path to save the visualization. Missing parent
            directories are created; the figure is closed after saving.

    Returns:
        matplotlib figure object or None if visualization fails (the error
        is printed, not raised)
    """
    fig = None
    try:
        # Kept local, as in the original block, so the function body is
        # self-contained.
        import networkx as nx
        import matplotlib.pyplot as plt

        node_spacing = 1.0
        layer_spacing = 2.0
        type_colors = {
            'input': 'lightcoral',
            'hidden': 'lightblue',
            'output': 'lightgreen',
            'bias': 'yellow',
        }

        genome = network.genome
        node_genes = genome.node_genes

        # Group node ids by type in a single pass over the genome.
        ids_by_type = {}
        for gene in node_genes.values():
            ids_by_type.setdefault(gene.node_type, []).append(gene.node_id)

        pos = {}
        node_colors = {}

        # One vertically centred column per layer.
        columns = (('input', 0), ('hidden', layer_spacing), ('output', 2 * layer_spacing))
        for node_type, x in columns:
            ids = ids_by_type.get(node_type, [])
            start_y = -(len(ids) - 1) * node_spacing / 2
            for i, node_idx in enumerate(ids):
                pos[node_idx] = (x, start_y + i * node_spacing)
                node_colors[node_idx] = type_colors[node_type]

        # The first bias node goes directly below the input column.
        bias_ids = ids_by_type.get('bias', [])
        if bias_ids:
            input_start_y = -(len(ids_by_type.get('input', [])) - 1) * node_spacing / 2
            pos[bias_ids[0]] = (0, input_start_y - node_spacing)
            node_colors[bias_ids[0]] = type_colors['bias']

        G = nx.DiGraph()
        for node_id in node_genes:
            G.add_node(node_id)
            # Fallbacks for nodes the layered layout did not place.
            if node_id not in node_colors:
                node_colors[node_id] = type_colors.get(node_genes[node_id].node_type, 'gray')
            if node_id not in pos:
                pos[node_id] = (layer_spacing, 0)

        for conn in genome.connection_genes:
            if not conn.enabled:
                continue
            magnitude = abs(conn.weight)
            G.add_edge(
                conn.source,
                conn.target,
                weight=magnitude * 2.0,
                color='red' if conn.weight < 0 else 'green',
                alpha=min(magnitude, 1.0),
            )

        fig = plt.figure(figsize=(12, 8))
        nx.draw_networkx_nodes(
            G, pos,
            node_color=[node_colors[node] for node in G.nodes()],
            node_size=800, alpha=0.8,
        )

        edge_attrs = [attrs for _, _, attrs in G.edges(data=True)]
        if edge_attrs:
            max_weight = max(attrs['weight'] for attrs in edge_attrs)
            # All-zero weights would otherwise divide by zero here.
            if max_weight > 0:
                widths = [3 * attrs['weight'] / max_weight for attrs in edge_attrs]
            else:
                widths = [0.0 for _ in edge_attrs]
            nx.draw_networkx_edges(
                G, pos,
                edge_color=[attrs['color'] for attrs in edge_attrs],
                width=widths,
                alpha=[attrs['alpha'] for attrs in edge_attrs],
                arrows=True, arrowsize=20,
            )

        nx.draw_networkx_labels(G, pos, font_size=10)
        plt.title("Neural Network Architecture")
        plt.axis('off')

        if save_path:
            # dirname is '' for a bare filename, and os.makedirs('') raises.
            parent_dir = os.path.dirname(save_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            plt.savefig(save_path, bbox_inches='tight', dpi=300)
            plt.close(fig)
        return fig
    except Exception as e:
        print(f"Error visualizing network: {str(e)}")
        # Do not leave a half-drawn figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
        return None
def plot_activation_distribution(population: List[Genome], save_path: Optional[str] = None):
    """Plot the distribution of node types in the population.

    Despite the name, this charts node-type totals, not activation
    functions: it sums ``input_size``, ``len(hidden_nodes)`` and
    ``output_size`` over all genomes and draws one bar per node type,
    labelled with its total.

    Args:
        population: List of genomes in the population
        save_path: Optional path to save the plot. Missing parent
            directories are created; the figure is closed after saving.

    Returns:
        matplotlib figure object or None if the population is empty or
        plotting fails (errors are printed, not raised)
    """
    fig = None
    try:
        node_type_counts = defaultdict(int)
        for genome in population:
            node_type_counts['input'] += genome.input_size
            node_type_counts['hidden'] += len(genome.hidden_nodes)
            node_type_counts['output'] += genome.output_size

        # Only an empty population leaves the tally without any keys.
        if not node_type_counts:
            print("No nodes found in population")
            return None

        node_types = list(node_type_counts)
        counts = [node_type_counts[node_type] for node_type in node_types]
        colors = {'input': 'lightcoral', 'hidden': 'lightblue', 'output': 'lightgreen'}

        fig = plt.figure(figsize=(10, 6))
        plt.bar(node_types, counts, color=[colors[t] for t in node_types], alpha=0.7)
        plt.title('Distribution of Node Types in Population')
        plt.xlabel('Node Type')
        plt.ylabel('Total Count')
        # Write each total just above its bar.
        for i, count in enumerate(counts):
            plt.text(i, count, str(count), ha='center', va='bottom')
        plt.tight_layout()

        if save_path:
            # dirname is '' for a bare filename, and os.makedirs('') raises.
            parent_dir = os.path.dirname(save_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            plt.savefig(save_path, bbox_inches='tight', dpi=300)
            plt.close(fig)
        return fig
    except Exception as e:
        print(f"Error plotting activation distribution: {str(e)}")
        # Do not leave a half-drawn figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
        return None
def _has_data(series) -> bool:
    """Return True for a non-empty sequence; works for lists and numpy arrays."""
    return series is not None and len(series) > 0


def _save_trend_plot(generations, series, title, ylabel, save_path) -> None:
    """Plot labelled (label, values) series over generations and save to save_path."""
    fig = plt.figure(figsize=(10, 6))
    try:
        for label, values in series:
            plt.plot(generations, values, label=label, alpha=0.7)
        plt.title(title)
        plt.xlabel('Generation')
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
    finally:
        # Close even when plotting or saving fails so figures do not pile up.
        plt.close(fig)


def analyze_evolution_trends(stats: Dict, save_dir: str) -> None:
    """Analyze and plot evolution trends from training history.

    Writes up to three PNG files into ``save_dir`` (created if missing):

    * ``fitness_trends.png`` from ``mean_fitness`` / ``best_fitness``,
    * ``complexity_trends.png`` from ``mean_complexity`` / ``best_complexity``,
    * ``species_trends.png`` from ``n_species``.

    The generation axis is ``0 .. len(stats['mean_fitness']) - 1``. A plot is
    skipped, with a printed message, when its data is missing or its length
    differs from the number of generations. Errors are printed, not raised.

    Args:
        stats: Dictionary containing training statistics
        save_dir: Directory to save plots
    """
    try:
        os.makedirs(save_dir, exist_ok=True)

        if not stats or not _has_data(stats.get('mean_fitness')):
            print("No evolution stats available yet")
            return

        n_generations = len(stats['mean_fitness'])
        generations = list(range(n_generations))

        metrics = {
            'Fitness': {
                'mean': stats.get('mean_fitness', []),
                'best': stats.get('best_fitness', []),
            },
            'Complexity': {
                'mean': stats.get('mean_complexity', []),
                'best': stats.get('best_complexity', []),
            },
        }
        for metric_name, metric_data in metrics.items():
            mean_values = metric_data['mean']
            best_values = metric_data['best']
            if not (_has_data(mean_values) and _has_data(best_values)):
                print(f"No data available for {metric_name}")
                continue
            if len(mean_values) != n_generations or len(best_values) != n_generations:
                print(f"Data length mismatch for {metric_name}")
                continue
            _save_trend_plot(
                generations,
                [(f'Mean {metric_name}', mean_values), (f'Best {metric_name}', best_values)],
                f'{metric_name} Over Generations',
                metric_name,
                os.path.join(save_dir, f'{metric_name.lower()}_trends.png'),
            )

        n_species = stats.get('n_species')
        if _has_data(n_species):
            if len(n_species) == n_generations:
                _save_trend_plot(
                    generations,
                    [('Number of Species', n_species)],
                    'Number of Species Over Generations',
                    'Number of Species',
                    os.path.join(save_dir, 'species_trends.png'),
                )
            else:
                print("Species count data length mismatch")
    except Exception as e:
        print(f"Error analyzing evolution trends: {str(e)}")