|
"""Analysis utilities for neural networks. |
|
|
|
This module provides functions for analyzing neural network architectures, |
|
including complexity measures and structural properties. |
|
""" |
|
|
|
import numpy as np |
|
import networkx as nx |
|
import matplotlib.pyplot as plt |
|
from typing import Dict, Tuple, Union, Optional, List, Any |
|
from .network import Network |
|
from .genome import Genome |
|
from collections import defaultdict |
|
import os |
|
|
|
def analyze_network_complexity(network: Network) -> Dict[str, Any]:
    """Summarize the structural complexity of a network's genome.

    Reported metrics:

    1. Node counts by role (input, hidden, output).
    2. Connection count: every entry of ``genome.connections`` is counted;
       no enabled/disabled filtering happens here.
    3. Density: connections divided by the number of candidate links, where
       any node may be the source and any hidden or output node may be the
       target (node-to-itself pairs are included in the count).
    4. Activation functions: every hidden and output node is tallied under
       ``'relu'``. The genome is not inspected for per-node activations.

    Args:
        network: Network instance to analyze; only ``network.genome`` is read.

    Returns:
        Dictionary with keys ``n_input``, ``n_hidden``, ``n_output``,
        ``n_connections``, ``density`` (``0`` when there are no candidate
        links) and ``activation_functions``.
    """
    g = network.genome
    counts = {
        'n_input': g.input_size,
        'n_hidden': len(g.hidden_nodes),
        'n_output': g.output_size,
    }
    edge_count = len(g.connections)

    # Candidate links: any node as source, any non-input node as target.
    targets = counts['n_hidden'] + counts['n_output']
    sources = counts['n_input'] + targets
    candidates = sources * targets
    density = 0 if candidates <= 0 else edge_count / candidates

    return {
        **counts,
        'n_connections': edge_count,
        'density': density,
        'activation_functions': {'relu': targets},
    }
|
|
|
def get_network_stats(network: Network) -> Dict[str, float]:
    """Get statistical measures of network properties.

    Computes various statistics about the network structure and parameters:

    - Number of nodes (total and hidden) and connections
    - Mean and (population) standard deviation of weights and biases
    - Network density and depth

    Args:
        network: Network instance to analyze. Reads ``n_nodes``,
            ``input_size``, ``output_size``, ``weights`` and ``bias``
            (mappings with numeric values) and ``weight_matrix`` (indexed
            as ``weight_matrix[src, dst]``; non-zero means connected).

    Returns:
        Dictionary with keys ``n_nodes``, ``n_hidden``, ``n_connections``,
        ``weight_mean``, ``weight_std``, ``n_biases``, ``bias_mean``,
        ``bias_std``, ``density`` and ``depth``. Mean/std of an empty
        weight or bias collection are reported as ``0.0`` instead of NaN.
    """
    stats = {}

    stats['n_nodes'] = network.n_nodes
    stats['n_hidden'] = network.n_nodes - network.input_size - network.output_size

    weights = np.array(list(network.weights.values()))
    stats['n_connections'] = len(weights)
    stats['weight_mean'], stats['weight_std'] = _mean_std(weights)

    biases = np.array(list(network.bias.values()))
    stats['n_biases'] = len(biases)
    stats['bias_mean'], stats['bias_std'] = _mean_std(biases)

    # Density over all ordered node pairs, excluding self-connections.
    # NOTE: analyze_network_complexity uses a different denominator, so the
    # two "density" values are not directly comparable.
    n_possible = network.n_nodes * (network.n_nodes - 1)
    stats['density'] = len(weights) / n_possible if n_possible > 0 else 0

    stats['depth'] = _bfs_depth(network.weight_matrix, network.n_nodes, network.input_size)

    return stats


def _mean_std(values: np.ndarray) -> Tuple[float, float]:
    """Return ``(mean, std)`` of *values* as floats, or ``(0.0, 0.0)`` if empty."""
    # np.mean/np.std of an empty array return NaN and emit a RuntimeWarning.
    if values.size == 0:
        return 0.0, 0.0
    return float(np.mean(values)), float(np.std(values))


def _bfs_depth(weight_matrix, n_nodes: int, input_size: int) -> int:
    """Count breadth-first layers reachable from the input nodes.

    Starts from nodes ``0..input_size-1`` and expands along non-zero entries
    of ``weight_matrix[src, dst]``. Each non-empty layer of newly reached,
    previously unvisited nodes adds one to the depth.
    """
    visited = set(range(input_size))
    frontier = set(visited)
    depth = 0
    # The depth bound is a safety cap; `visited` already prevents cycles.
    while frontier and depth < n_nodes:
        next_frontier = set()
        for node in frontier:
            for nxt in range(n_nodes):
                if weight_matrix[node, nxt] != 0 and nxt not in visited:
                    next_frontier.add(nxt)
                    visited.add(nxt)
        frontier = next_frontier
        if frontier:
            depth += 1
    return depth
|
|
|
def visualize_network_architecture(network: Network, save_path: Optional[str] = None):
    """Visualize the network architecture using networkx.

    Creates a layered visualization of the neural network with:

    - Input nodes in light red (leftmost layer)
    - Hidden nodes in light blue (middle layer)
    - Output nodes in light green (rightmost layer)
    - The first bias node (if any) in yellow, one step below the inputs
    - Enabled connections drawn as arrows whose width is proportional to
      ``|weight|`` (the largest is drawn at width 3), red for negative and
      green for non-negative weights, with opacity ``min(|weight|, 1)``

    Args:
        network: Network instance to visualize. Reads ``genome.node_genes``
            (mapping of node id to gene with ``node_id``/``node_type``) and
            ``genome.connection_genes`` (genes with ``source``, ``target``,
            ``weight`` and ``enabled``).
        save_path: Optional path to save the visualization. Parent
            directories are created when the path has one. When given, the
            figure is saved and then closed before being returned.

    Returns:
        matplotlib figure object, or None if visualization fails (the error
        is printed and any partially built figure is closed).
    """
    fig = None
    try:
        import networkx as nx
        import matplotlib.pyplot as plt

        genome = network.genome
        node_genes = genome.node_genes
        graph = nx.DiGraph()

        node_spacing = 1.0
        layer_spacing = 2.0
        type_colors = {
            'input': 'lightcoral',
            'hidden': 'lightblue',
            'output': 'lightgreen',
            'bias': 'yellow',
        }

        # Group node ids by role, preserving genome order within each group.
        ids_by_type = defaultdict(list)
        for gene in node_genes.values():
            ids_by_type[gene.node_type].append(gene.node_id)

        pos = {}
        node_colors = {}

        # Input/hidden/output become vertically centred columns.
        columns = (('input', 0), ('hidden', layer_spacing), ('output', 2 * layer_spacing))
        for node_type, x in columns:
            ids = ids_by_type[node_type]
            start_y = -(len(ids) - 1) * node_spacing / 2
            for i, node_id in enumerate(ids):
                pos[node_id] = (x, start_y + i * node_spacing)
                node_colors[node_id] = type_colors[node_type]

        # Only the first bias node gets a dedicated slot below the inputs.
        bias_ids = ids_by_type['bias']
        if bias_ids:
            input_start_y = -(len(ids_by_type['input']) - 1) * node_spacing / 2
            pos[bias_ids[0]] = (0, input_start_y - node_spacing)
            node_colors[bias_ids[0]] = type_colors['bias']

        # Remaining genome nodes (extra bias nodes, unknown types) get a
        # type-based color and a fallback position in the middle column.
        for node_id, gene in node_genes.items():
            graph.add_node(node_id)
            node_colors.setdefault(node_id, type_colors.get(gene.node_type, 'gray'))
            pos.setdefault(node_id, (layer_spacing, 0))

        for conn in genome.connection_genes:
            if not conn.enabled:
                continue
            magnitude = abs(conn.weight)
            graph.add_edge(conn.source, conn.target,
                           weight=magnitude,
                           color='red' if conn.weight < 0 else 'green',
                           alpha=min(magnitude, 1.0))

        # Edges may reference ids absent from node_genes; give those nodes
        # a fallback so drawing does not fail on a missing position/color.
        for node_id in graph.nodes:
            node_colors.setdefault(node_id, 'gray')
            pos.setdefault(node_id, (layer_spacing, 0))

        fig, ax = plt.subplots(figsize=(12, 8))

        nx.draw_networkx_nodes(graph, pos, ax=ax,
                               node_color=[node_colors[n] for n in graph.nodes()],
                               node_size=800, alpha=0.8)

        edges = list(graph.edges(data=True))
        if edges:
            max_weight = max(data['weight'] for _, _, data in edges)
            # All-zero weights would otherwise divide by zero.
            scale = 3.0 / max_weight if max_weight > 0 else 0.0
            nx.draw_networkx_edges(graph, pos, ax=ax,
                                   edgelist=[(u, v) for u, v, _ in edges],
                                   edge_color=[data['color'] for _, _, data in edges],
                                   width=[data['weight'] * scale for _, _, data in edges],
                                   alpha=[data['alpha'] for _, _, data in edges],
                                   arrows=True, arrowsize=20)

        nx.draw_networkx_labels(graph, pos, ax=ax, font_size=10)

        ax.set_title("Neural Network Architecture")
        ax.axis('off')

        if save_path:
            # dirname is '' for a bare filename; os.makedirs('') would raise.
            directory = os.path.dirname(save_path)
            if directory:
                os.makedirs(directory, exist_ok=True)
            fig.savefig(save_path, bbox_inches='tight', dpi=300)
            plt.close(fig)

        return fig

    except Exception as e:
        print(f"Error visualizing network: {str(e)}")
        # Don't leave a half-drawn figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
        return None
|
|
|
def plot_activation_distribution(population: List[Genome], save_path: Optional[str] = None):
    """Plot the distribution of node types in the population.

    Despite the function name, this plots node *types*, not activation
    functions. It draws a bar chart of the total number of input, hidden
    and output nodes, summed over every genome in ``population``.

    Args:
        population: List of genomes in the population. Each genome is read
            for ``input_size``, ``output_size`` and ``hidden_nodes``.
        save_path: Optional path to save the plot. Parent directories are
            created when the path has one. When given, the figure is saved
            and then closed before being returned.

    Returns:
        matplotlib figure object, or None if the population is empty or
        plotting fails (a message is printed in both cases).
    """
    fig = None
    try:
        node_type_counts = defaultdict(int)
        for genome in population:
            node_type_counts['input'] += genome.input_size
            node_type_counts['hidden'] += len(genome.hidden_nodes)
            node_type_counts['output'] += genome.output_size

        # Only empty when the population itself is empty.
        if not node_type_counts:
            print("No nodes found in population")
            return None

        fig = plt.figure(figsize=(10, 6))

        node_types = list(node_type_counts.keys())
        counts = list(node_type_counts.values())

        colors = {'input': 'lightcoral', 'hidden': 'lightblue', 'output': 'lightgreen'}
        plt.bar(node_types, counts, color=[colors[t] for t in node_types], alpha=0.7)

        plt.title('Distribution of Node Types in Population')
        plt.xlabel('Node Type')
        plt.ylabel('Total Count')

        # Annotate each bar with its count.
        for i, count in enumerate(counts):
            plt.text(i, count, str(count), ha='center', va='bottom')

        plt.tight_layout()

        if save_path:
            # dirname is '' for a bare filename; os.makedirs('') would raise.
            directory = os.path.dirname(save_path)
            if directory:
                os.makedirs(directory, exist_ok=True)
            plt.savefig(save_path, bbox_inches='tight', dpi=300)
            plt.close(fig)

        return fig

    except Exception as e:
        print(f"Error plotting activation distribution: {str(e)}")
        # Don't leave a half-drawn figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
        return None
|
|
|
def analyze_evolution_trends(stats: Dict, save_dir: str) -> None:
    """Analyze and plot evolution trends from training history.

    Writes up to three PNG files into ``save_dir``:

    - ``fitness_trends.png`` from ``stats['mean_fitness']`` and ``stats['best_fitness']``
    - ``complexity_trends.png`` from ``stats['mean_complexity']`` and ``stats['best_complexity']``
    - ``species_trends.png`` from ``stats['n_species']``

    Each series is plotted against the generation index
    ``0 .. len(stats['mean_fitness']) - 1``. A metric is skipped, with a
    printed message, when its data is missing or its length differs from
    the number of generations. Errors are printed rather than raised.

    Args:
        stats: Dictionary containing training statistics, keyed as above
            with one value per generation.
        save_dir: Directory to save plots; created if it does not exist.
    """
    try:
        os.makedirs(save_dir, exist_ok=True)

        if not stats or 'mean_fitness' not in stats or not stats['mean_fitness']:
            print("No evolution stats available yet")
            return

        # mean_fitness is non-empty here, so at least one generation exists.
        generations = list(range(len(stats['mean_fitness'])))
        n_generations = len(generations)

        metrics = {
            'Fitness': {
                'mean': stats.get('mean_fitness', []),
                'best': stats.get('best_fitness', [])
            },
            'Complexity': {
                'mean': stats.get('mean_complexity', []),
                'best': stats.get('best_complexity', [])
            }
        }

        for metric_name, metric_data in metrics.items():
            mean_values = metric_data['mean']
            best_values = metric_data['best']

            if not mean_values or not best_values:
                print(f"No data available for {metric_name}")
                continue

            if len(mean_values) != n_generations or len(best_values) != n_generations:
                print(f"Data length mismatch for {metric_name}")
                continue

            _save_trend_plot(
                generations,
                [(f'Mean {metric_name}', mean_values), (f'Best {metric_name}', best_values)],
                f'{metric_name} Over Generations',
                metric_name,
                os.path.join(save_dir, f'{metric_name.lower()}_trends.png'),
            )

        if 'n_species' in stats and stats['n_species']:
            n_species = stats['n_species']
            if len(n_species) == n_generations:
                _save_trend_plot(
                    generations,
                    [('Number of Species', n_species)],
                    'Number of Species Over Generations',
                    'Number of Species',
                    os.path.join(save_dir, 'species_trends.png'),
                )
            else:
                print("Species count data length mismatch")

    except Exception as e:
        print(f"Error analyzing evolution trends: {str(e)}")


def _save_trend_plot(generations, series, title, ylabel, save_path):
    """Plot labelled ``(label, values)`` series against *generations* and save to *save_path*.

    The figure is always closed, even when plotting or saving raises, so a
    failure cannot leak figures into pyplot's global registry.
    """
    fig = plt.figure(figsize=(10, 6))
    try:
        for label, values in series:
            plt.plot(generations, values, label=label, alpha=0.7)
        plt.title(title)
        plt.xlabel('Generation')
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
    finally:
        plt.close(fig)
|
|