"""Visualization utilities for NEAT networks.""" |
import matplotlib.pyplot as plt |
import networkx as nx |
import numpy as np |
import jax.numpy as jnp |
from typing import List, Dict, Any |
from .network import Network |
def _layer_positions(nodes: List[Any], x: float) -> Dict[Any, tuple]:
    """Return ``{node: (x, y)}`` for one layer column, centred on y = 0.

    Nodes are spread evenly over y in [-0.5, 0.5]; a single node sits at y = 0.
    """
    n = len(nodes)
    span = max(1, n - 1)
    return {node: (x, (i - (n - 1) / 2) / span) for i, node in enumerate(nodes)}


def plot_network_structure(network: Network, title: str = "Network Structure",
                           save_path: str = None, show: bool = True) -> None:
    """Plot network structure using networkx.

    Input, hidden and output nodes are drawn as three columns at x = 0, 1, 2
    (light blue, light green and salmon respectively). Edges are coloured with
    the ``RdYlBu`` colormap on a scale symmetric around zero: the most negative
    weight maps to the red end, a zero weight to the middle, and the most
    positive weight to the blue end.

    Args:
        network: Network to visualize
        title: Plot title
        save_path: Path to save plot to (nothing is saved if falsy)
        show: Whether to display plot; if False the figure is closed instead
    """
    G = nx.DiGraph()
    pos = {}
    layers = (
        ('input', network.get_input_nodes(), 0),
        ('hidden', network.get_hidden_nodes(), 1),
        ('output', network.get_output_nodes(), 2),
    )
    for layer, nodes, x in layers:
        for node in nodes:
            G.add_node(node, layer=layer)
        pos.update(_layer_positions(nodes, x))

    for src, dst, weight in network.get_connections():
        # float() accepts plain Python numbers as well as NumPy/JAX scalars.
        G.add_edge(src, dst, weight=float(weight))

    plt.figure(figsize=(8, 6))

    layer_colors = {'input': 'lightblue', 'hidden': 'lightgreen'}
    node_colors = [layer_colors.get(G.nodes[n]['layer'], 'salmon')
                   for n in G.nodes()]
    nx.draw_networkx_nodes(G, pos, node_color=node_colors)

    weights = [G[u][v]['weight'] for u, v in G.edges()]
    # A network without connections has nothing to draw (and max() of an
    # empty sequence would raise).
    if weights:
        max_weight = max(abs(w) for w in weights)
        if max_weight > 0:
            # Map [-max_weight, max_weight] onto [0, 1] so that 0.5 means zero.
            norm_weights = [(w + max_weight) / (2 * max_weight) for w in weights]
        else:
            norm_weights = [0.5] * len(weights)
        # Pin the colour range to [0, 1]. Without this, networkx rescales to
        # the min/max of edge_color and the zero-centred scale is lost.
        nx.draw_networkx_edges(G, pos, edge_color=norm_weights,
                               edge_cmap=plt.cm.RdYlBu,
                               edge_vmin=0.0, edge_vmax=1.0, width=2)

    labels = {n: str(n) for n in G.nodes()}
    nx.draw_networkx_labels(G, pos, labels)

    plt.title(title)
    plt.axis('off')
    if save_path:
        plt.savefig(save_path)
    if show:
        plt.show()
    else:
        plt.close()
def plot_decision_boundary(network: Network, X: np.ndarray, y: np.ndarray,
                           title: str = "Decision Boundary",
                           save_path: str = None, show: bool = True,
                           h: float = 0.02) -> None:
    """Plot decision boundary for 2D classification problem.

    ``network.predict`` is evaluated on a regular grid covering the data with
    a margin of 1 on every side. Its output is drawn as filled contours, and
    the samples are scattered on top, coloured by label.

    Args:
        network: Trained network
        X: Input data of shape (n_samples, 2); NumPy, JAX or nested lists
        y: Labels, one per sample
        title: Plot title
        save_path: Path to save plot to (nothing is saved if falsy)
        show: Whether to display plot; if False the figure is closed instead
        h: Grid step size. Increase it for data with a large range, because
            the number of grid points grows as (range / h) ** 2.

    Raises:
        ValueError: If ``X`` is not two-dimensional with exactly 2 columns,
            or if ``h`` is not positive.
    """
    # np.asarray is a no-op for NumPy arrays and also converts JAX arrays/lists.
    X = np.asarray(X)
    y = np.asarray(y)
    if X.ndim != 2 or X.shape[1] != 2:
        raise ValueError(f"X must have shape (n_samples, 2), got {X.shape}")
    if h <= 0:
        raise ValueError(f"h must be positive, got {h}")

    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
    mesh_points = np.c_[xx.ravel(), yy.ravel()]

    # Assumes predict yields one score per grid point ((N,) or (N, 1)),
    # so that it reshapes onto the grid -- TODO confirm against Network.predict.
    Z = np.asarray(network.predict(mesh_points)).reshape(xx.shape)

    plt.figure(figsize=(8, 6))
    plt.contourf(xx, yy, Z, cmap=plt.cm.RdYlBu_r, alpha=0.3)
    plt.scatter(X[:, 0], X[:, 1], c=y,
                cmap=plt.cm.RdYlBu_r,
                alpha=0.6,
                edgecolors='gray')
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.title(title)

    if save_path:
        plt.savefig(save_path)
    if show:
        plt.show()
    else:
        plt.close()
def plot_training_history(history: Dict[str, List[float]],
                          title: str = "Training History",
                          save_path: str = None, show: bool = True) -> None:
    """Plot training history metrics, one line per metric, against generation.

    Args:
        history: Mapping of metric name to its per-generation values. Values
            may be plain numbers or JAX scalars; lists may be empty.
        title: Plot title
        save_path: Path to save plot to (nothing is saved if falsy)
        show: Whether to display plot; if False the figure is closed instead
    """
    plt.figure(figsize=(10, 6))

    for metric, values in history.items():
        # Convert element-wise rather than inspecting values[0]: that would
        # raise on an empty list and miss JAX scalars in a mixed list.
        plot_values = [float(v) if isinstance(v, jnp.ndarray) else v
                       for v in values]
        plt.plot(plot_values, label=metric)

    plt.title(title)
    plt.xlabel('Generation')
    plt.ylabel('Value')
    # With no metrics there are no labelled lines, and legend() would warn.
    if history:
        plt.legend()
    plt.grid(True)

    if save_path:
        plt.savefig(save_path)
    if show:
        plt.show()
    else:
        plt.close()