|
"""Visualization utilities for NEAT networks.""" |
|
|
|
import matplotlib.pyplot as plt |
|
import networkx as nx |
|
import numpy as np |
|
import jax.numpy as jnp |
|
from typing import List, Dict, Any |
|
from .network import Network |
|
|
|
def plot_network_structure(network: Network, title: str = "Network Structure",
                           save_path: str = None, show: bool = True) -> None:
    """Plot a network's structure as a layered directed graph using networkx.

    Input, hidden and output nodes are drawn in three columns (x = 0, 1, 2),
    each column vertically centred on y = 0. Edges are coloured on the
    ``RdYlBu`` colormap by connection weight. Weights are scaled symmetrically,
    so a weight of 0 maps to the colormap midpoint, ``-max|w|`` maps to one
    end and ``+max|w|`` maps to the other.

    Args:
        network: Network to visualize. Must provide ``get_input_nodes()``,
            ``get_hidden_nodes()``, ``get_output_nodes()`` (sized sequences
            of node ids) and ``get_connections()`` (an iterable of
            ``(src, dst, weight)`` triples).
        title: Plot title.
        save_path: If given, the figure is saved to this path.
        show: If True, display the figure with ``plt.show()``; otherwise the
            figure is closed after (optionally) saving.
    """
    graph = nx.DiGraph()
    pos = {}

    layers = (
        ('input', network.get_input_nodes()),
        ('hidden', network.get_hidden_nodes()),
        ('output', network.get_output_nodes()),
    )
    for column, (layer, nodes) in enumerate(layers):
        n_nodes = len(nodes)
        # Centre every column on y = 0 so layers of different sizes line up.
        # The old formula (i - n/2) shifted each column by a size-dependent offset.
        spread = max(1, n_nodes - 1)
        for i, node in enumerate(nodes):
            graph.add_node(node, layer=layer)
            pos[node] = (column, (i - (n_nodes - 1) / 2) / spread)

    for src, dst, weight in network.get_connections():
        # float() also unwraps JAX/NumPy scalars so networkx stores plain numbers.
        graph.add_edge(src, dst, weight=float(weight))

    plt.figure(figsize=(8, 6))

    layer_colors = {'input': 'lightblue', 'hidden': 'lightgreen'}
    node_colors = [layer_colors.get(graph.nodes[n]['layer'], 'salmon')
                   for n in graph.nodes()]
    nx.draw_networkx_nodes(graph, pos, node_color=node_colors)

    edgelist = list(graph.edges())
    # A network without connections has nothing to colour; min()/max() on
    # an empty weight list would otherwise raise ValueError.
    if edgelist:
        weights = [graph[u][v]['weight'] for u, v in edgelist]
        max_weight = max(abs(w) for w in weights)
        if max_weight > 0:
            norm_weights = [(w + max_weight) / (2 * max_weight) for w in weights]
        else:
            norm_weights = [0.5] * len(weights)
        # Pin the colour scale to [0, 1]. Without edge_vmin/edge_vmax,
        # networkx rescales to min/max(norm_weights) and the symmetric,
        # zero-centred normalisation above is lost.
        nx.draw_networkx_edges(graph, pos, edgelist=edgelist,
                               edge_color=norm_weights,
                               edge_cmap=plt.cm.RdYlBu,
                               edge_vmin=0.0, edge_vmax=1.0, width=2)

    labels = {n: str(n) for n in graph.nodes()}
    nx.draw_networkx_labels(graph, pos, labels)

    plt.title(title)
    plt.axis('off')

    if save_path:
        plt.savefig(save_path)

    if show:
        plt.show()
    else:
        plt.close()
|
|
|
def plot_decision_boundary(network: Network, X: np.ndarray, y: np.ndarray,
                           title: str = "Decision Boundary",
                           save_path: str = None, show: bool = True,
                           h: float = 0.02) -> None:
    """Plot the decision boundary of a network on a 2D classification problem.

    The network's predictions are evaluated on a regular grid covering the
    data range (padded by 1 on each side). They are drawn as a filled contour
    plot, with the samples scattered on top and coloured by label.

    Args:
        network: Trained network; its ``predict`` is called with an
            ``(n_grid_points, 2)`` NumPy array. Its output must be
            reshapeable to the grid shape (one value per grid point).
        X: Input data of shape (n_samples, 2). NumPy arrays, JAX arrays and
            plain nested sequences are accepted. Only the first two columns
            are used for the grid.
        y: Labels, one per sample; used as scatter colours.
        title: Plot title.
        save_path: If given, the figure is saved to this path.
        show: If True, display the figure with ``plt.show()``; otherwise the
            figure is closed after (optionally) saving.
        h: Grid step size. The number of grid points grows as 1/h**2, so
            increase this for data with a wide range.

    Raises:
        ValueError: If ``X`` is empty or not 2-D with at least two columns,
            or if ``h`` is not positive.
    """
    # np.asarray handles NumPy arrays, JAX arrays and plain sequences alike.
    X = np.asarray(X)
    y = np.asarray(y)
    if X.ndim != 2 or X.shape[1] < 2 or X.shape[0] == 0:
        raise ValueError(f"X must have shape (n_samples, 2) with n_samples > 0, got {X.shape}")
    if h <= 0:
        raise ValueError(f"h must be positive, got {h}")

    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))

    mesh_points = np.c_[xx.ravel(), yy.ravel()]
    Z = np.asarray(network.predict(mesh_points)).reshape(xx.shape)

    plt.figure(figsize=(8, 6))
    plt.contourf(xx, yy, Z, cmap=plt.cm.RdYlBu_r, alpha=0.3)

    plt.scatter(X[:, 0], X[:, 1], c=y,
                cmap=plt.cm.RdYlBu_r,
                alpha=0.6,
                edgecolors='gray')
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())

    plt.title(title)

    if save_path:
        plt.savefig(save_path)

    if show:
        plt.show()
    else:
        plt.close()
|
|
|
def plot_training_history(history: Dict[str, List[float]],
                          title: str = "Training History",
                          save_path: str = None, show: bool = True) -> None:
    """Plot each metric in a training history as a line over generations.

    Args:
        history: Mapping from metric name to its per-generation values.
            Values may be Python numbers or JAX arrays (each JAX element is
            converted with ``float``). Empty lists are allowed.
        title: Plot title.
        save_path: If given, the figure is saved to this path.
        show: If True, display the figure with ``plt.show()``; otherwise the
            figure is closed after (optionally) saving.
    """
    plt.figure(figsize=(10, 6))

    for metric, values in history.items():
        # Convert per element rather than inspecting values[0]. That check
        # raised IndexError on an empty list and mishandled lists mixing
        # JAX arrays with plain floats.
        series = [float(v) if isinstance(v, jnp.ndarray) else v for v in values]
        plt.plot(series, label=metric)

    plt.title(title)
    plt.xlabel('Generation')
    plt.ylabel('Value')
    # With no metrics there are no labelled artists, and legend() would warn.
    if history:
        plt.legend()
    plt.grid(True)

    if save_path:
        plt.savefig(save_path)

    if show:
        plt.show()
    else:
        plt.close()
|
|