File size: 5,492 Bytes
0232428 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 |
"""Visualization utilities for NEAT networks."""
from typing import Any, Dict, List, Optional

import jax.numpy as jnp
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

from .network import Network
def _layer_positions(nodes, x: float) -> Dict[Any, tuple]:
    """Place *nodes* in a vertical column at horizontal position *x*.

    Nodes are spread evenly over ``[-0.5, 0.5]`` and centred on ``y = 0`` so
    that columns with different node counts line up; a single node sits
    exactly at ``y = 0`` and an empty sequence yields an empty dict.
    """
    count = len(nodes)
    span = max(1, count - 1)  # avoid division by zero for 0 or 1 nodes
    centre = (count - 1) / 2
    return {node: (x, (i - centre) / span) for i, node in enumerate(nodes)}


def plot_network_structure(network: Network, title: str = "Network Structure",
                           save_path: Optional[str] = None,
                           show: bool = True) -> None:
    """Plot network structure using networkx.

    Input, hidden and output nodes are drawn as three columns at x = 0, 1, 2.
    Every connection becomes a directed edge whose colour encodes its weight
    on the ``RdYlBu`` colormap, scaled symmetrically so that a weight of 0
    maps to the middle of the colormap and ``-max|w|`` / ``+max|w|`` map to
    its two ends. Networks without connections are drawn as nodes only.

    Args:
        network: Network to visualize. Must provide ``get_input_nodes``,
            ``get_hidden_nodes``, ``get_output_nodes`` and ``get_connections``
            (the latter yielding ``(src, dst, weight)`` triples).
        title: Plot title.
        save_path: If given, the figure is saved to this path.
        show: If True, display the figure with ``plt.show()``; otherwise the
            figure is closed after (optionally) being saved.
    """
    G = nx.DiGraph()
    pos = {}

    # Add nodes layer by layer, recording the layer for colouring later.
    layers = (
        ('input', 0, network.get_input_nodes()),
        ('hidden', 1, network.get_hidden_nodes()),
        ('output', 2, network.get_output_nodes()),
    )
    for layer, x, nodes in layers:
        nodes = list(nodes)  # iterated twice below; materialise once
        for node in nodes:
            G.add_node(node, layer=layer)
        pos.update(_layer_positions(nodes, x))

    # Add edges. Weights may be JAX/NumPy scalars; coerce to plain floats.
    # NOTE(review): a connection referencing a node absent from all three
    # node lists creates a node with no 'layer' or position and will fail
    # below -- confirm Network guarantees this cannot happen.
    for src, dst, weight in network.get_connections():
        G.add_edge(src, dst, weight=float(weight))

    plt.figure(figsize=(8, 6))

    # Draw nodes: input -> lightblue, hidden -> lightgreen, otherwise salmon.
    layer_colors = {'input': 'lightblue', 'hidden': 'lightgreen'}
    node_colors = [layer_colors.get(G.nodes[n]['layer'], 'salmon')
                   for n in G.nodes()]
    nx.draw_networkx_nodes(G, pos, node_color=node_colors)

    # Draw edges coloured by weight (skipped entirely when there are none,
    # since min/max over an empty weight list would raise).
    edge_list = list(G.edges())
    if edge_list:
        weights = [G[u][v]['weight'] for u, v in edge_list]
        max_weight = max(abs(w) for w in weights)
        if max_weight > 0:
            norm_weights = [(w + max_weight) / (2 * max_weight)
                            for w in weights]
        else:
            # All weights are zero: use the middle colour.
            norm_weights = [0.5] * len(weights)
        # edge_vmin/edge_vmax pin the colour scale to [0, 1]; without them
        # networkx rescales to the min/max of edge_color and discards the
        # symmetric normalisation computed above.
        nx.draw_networkx_edges(G, pos, edgelist=edge_list,
                               edge_color=norm_weights,
                               edge_cmap=plt.cm.RdYlBu,
                               edge_vmin=0.0, edge_vmax=1.0, width=2)

    nx.draw_networkx_labels(G, pos, {n: str(n) for n in G.nodes()})
    plt.title(title)
    plt.axis('off')

    if save_path:
        plt.savefig(save_path)
    if show:
        plt.show()
    else:
        plt.close()
def plot_decision_boundary(network: Network, X: np.ndarray, y: np.ndarray,
                           title: str = "Decision Boundary",
                           save_path: Optional[str] = None, show: bool = True,
                           h: float = 0.02, padding: float = 1.0) -> None:
    """Plot decision boundary for 2D classification problem.

    The network is evaluated on a regular grid covering the data (extended
    by *padding* on every side). Its output is drawn as filled contours,
    with the data points scattered on top.

    Args:
        network: Trained network; ``network.predict`` is called with an
            ``(n_points, 2)`` NumPy array of grid coordinates.
        X: Input data of shape ``(n_samples, 2)``. NumPy arrays, JAX arrays
            or any array-like accepted by ``np.asarray``.
        y: Labels, used to colour the scattered points.
        title: Plot title.
        save_path: If given, the figure is saved to this path.
        show: If True, display the figure; otherwise close it after saving.
        h: Grid step size. The grid holds about
            ``(x_range / h) * (y_range / h)`` points, so raise this for data
            with a large spread to keep memory use bounded.
        padding: Margin added around the data range on each axis.

    Raises:
        ValueError: If *h* is not positive.
    """
    if h <= 0:
        raise ValueError(f"h must be positive, got {h}")

    # np.asarray is a no-op for NumPy input and converts JAX arrays / lists.
    X = np.asarray(X)
    y = np.asarray(y)

    # Regular grid over the padded data range.
    x_min, x_max = X[:, 0].min() - padding, X[:, 0].max() + padding
    y_min, y_max = X[:, 1].min() - padding, X[:, 1].max() + padding
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))

    # Evaluate the network at every grid point.
    mesh_points = np.c_[xx.ravel(), yy.ravel()]
    # NOTE(review): assumes predict returns one value per grid point (shape
    # (N,) or (N, 1)); a multi-output prediction would fail this reshape.
    Z = np.asarray(network.predict(mesh_points)).reshape(xx.shape)

    plt.figure(figsize=(8, 6))
    plt.contourf(xx, yy, Z, cmap=plt.cm.RdYlBu_r, alpha=0.3)
    plt.scatter(X[:, 0], X[:, 1], c=y,
                cmap=plt.cm.RdYlBu_r,
                alpha=0.6,
                edgecolors='gray')
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.title(title)

    if save_path:
        plt.savefig(save_path)
    if show:
        plt.show()
    else:
        plt.close()
def plot_training_history(history: Dict[str, List[float]],
                          title: str = "Training History",
                          save_path: Optional[str] = None,
                          show: bool = True) -> None:
    """Plot training history metrics.

    Each entry of *history* is drawn as one line, with the list index on the
    x-axis (labelled "Generation").

    Args:
        history: Mapping from metric name to its per-generation values.
            Elements may be plain numbers or JAX arrays; JAX arrays are
            converted element-wise with ``float``. Empty series are allowed.
        title: Plot title.
        save_path: If given, the figure is saved to this path.
        show: If True, display the figure; otherwise close it after saving.
    """
    plt.figure(figsize=(10, 6))

    for metric, values in history.items():
        # Convert element-wise: inspecting only values[0] raised IndexError
        # on empty series and missed JAX values appearing later in the list.
        series = [float(v) if isinstance(v, jnp.ndarray) else v
                  for v in values]
        plt.plot(series, label=metric)

    plt.title(title)
    plt.xlabel('Generation')
    plt.ylabel('Value')
    if history:
        # Avoid matplotlib's "No artists with labels" warning when empty.
        plt.legend()
    plt.grid(True)

    if save_path:
        plt.savefig(save_path)
    if show:
        plt.show()
    else:
        plt.close()
|