# neat/visualization.py
# Uploaded by eyad-silx with huggingface_hub (revision 0232428).
"""Visualization utilities for NEAT networks."""
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import jax.numpy as jnp
from typing import List, Dict, Any
from .network import Network
def plot_network_structure(network: Network, title: str = "Network Structure",
                           save_path: str = None, show: bool = True) -> None:
    """Plot a network's nodes and connections as a layered directed graph.

    Input, hidden and output nodes are drawn in three vertical columns
    (x = 0, 1, 2), each column centred on y = 0. Edges are coloured with the
    ``RdYlBu`` colormap on a scale symmetric around zero: a weight of 0 maps
    to the middle colour, and the largest-magnitude weight maps to one end.

    Args:
        network: Network to visualize. Must provide ``get_input_nodes()``,
            ``get_hidden_nodes()``, ``get_output_nodes()`` and
            ``get_connections()`` (iterable of ``(src, dst, weight)``).
        title: Plot title.
        save_path: If given, the figure is saved to this path.
        show: If True, display the figure with ``plt.show()``; otherwise the
            figure is closed after (optionally) being saved.
    """
    G = nx.DiGraph()
    pos = {}

    # (layer name, x coordinate, nodes) for each column of the layout.
    layers = (
        ('input', 0, network.get_input_nodes()),
        ('hidden', 1, network.get_hidden_nodes()),
        ('output', 2, network.get_output_nodes()),
    )
    for layer, x, nodes in layers:
        count = len(nodes)
        for i, node in enumerate(nodes):
            G.add_node(node, layer=layer)
            # Centre the column on y = 0: indices 0..count-1 map onto
            # [-0.5, 0.5], and a single node sits exactly at 0.
            pos[node] = (x, (i - (count - 1) / 2) / max(1, count - 1))

    for src, dst, weight in network.get_connections():
        # Store plain Python floats rather than JAX arrays in the graph.
        if isinstance(weight, jnp.ndarray):
            weight = float(weight)
        G.add_edge(src, dst, weight=weight)

    plt.figure(figsize=(8, 6))

    layer_colors = {'input': 'lightblue', 'hidden': 'lightgreen'}
    node_colors = [layer_colors.get(G.nodes[n]['layer'], 'salmon')
                   for n in G.nodes()]
    nx.draw_networkx_nodes(G, pos, node_color=node_colors)

    edges = list(G.edges())
    if edges:
        weights = [G[u][v]['weight'] for u, v in edges]
        # Symmetric scale: map [-max|w|, +max|w|] onto [0, 1].
        max_weight = max(abs(w) for w in weights)
        if max_weight > 0:
            norm_weights = [(w + max_weight) / (2 * max_weight) for w in weights]
        else:
            norm_weights = [0.5] * len(weights)  # all-zero weights -> middle colour
        # Pin the colormap range to [0, 1]. Otherwise networkx rescales to the
        # data's own min/max and the zero-centred normalisation is lost.
        nx.draw_networkx_edges(G, pos, edgelist=edges,
                               edge_color=norm_weights,
                               edge_cmap=plt.cm.RdYlBu,
                               edge_vmin=0.0, edge_vmax=1.0,
                               width=2)

    nx.draw_networkx_labels(G, pos, {n: str(n) for n in G.nodes()})

    plt.title(title)
    plt.axis('off')

    if save_path:
        plt.savefig(save_path)
    if show:
        plt.show()
    else:
        plt.close()
def plot_decision_boundary(network: Network, X: np.ndarray, y: np.ndarray,
                           title: str = "Decision Boundary",
                           save_path: str = None, show: bool = True,
                           step: float = 0.02, margin: float = 1.0) -> None:
    """Plot the decision surface of a network on a 2-D classification problem.

    The network's ``predict`` is evaluated on a regular grid covering the
    data plus ``margin`` on every side. The result is drawn as filled
    contours, with the training points overlaid.

    Args:
        network: Trained network exposing ``predict(points)``, whose output
            must hold one value per grid point.
        X: Input data of shape (n_samples, 2); NumPy, JAX or nested lists.
        y: Labels, one per sample (used as scatter colours).
        title: Plot title.
        save_path: If given, the figure is saved to this path.
        show: If True, display the figure; otherwise it is closed.
        step: Grid spacing. Lower values give a finer mesh but evaluate
            quadratically more points.
        margin: Padding added around the data's bounding box.

    Raises:
        ValueError: If ``X`` is not 2-D with at least two columns, or if
            ``step`` is not positive.
    """
    # np.asarray handles NumPy, JAX arrays and plain lists alike.
    X = np.asarray(X)
    y = np.asarray(y).ravel()
    if X.ndim != 2 or X.shape[1] < 2:
        raise ValueError(f"X must have shape (n_samples, 2), got {X.shape}")
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")

    x_min, x_max = X[:, 0].min() - margin, X[:, 0].max() + margin
    y_min, y_max = X[:, 1].min() - margin, X[:, 1].max() + margin
    xx, yy = np.meshgrid(np.arange(x_min, x_max, step),
                         np.arange(y_min, y_max, step))

    mesh_points = np.c_[xx.ravel(), yy.ravel()]
    Z = np.asarray(network.predict(mesh_points)).reshape(xx.shape)

    plt.figure(figsize=(8, 6))
    plt.contourf(xx, yy, Z, cmap=plt.cm.RdYlBu_r, alpha=0.3)

    plt.scatter(X[:, 0], X[:, 1], c=y,
                cmap=plt.cm.RdYlBu_r,
                alpha=0.6,
                edgecolors='gray')

    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.title(title)

    if save_path:
        plt.savefig(save_path)
    if show:
        plt.show()
    else:
        plt.close()
def plot_training_history(history: Dict[str, List[float]],
                          title: str = "Training History",
                          save_path: str = None, show: bool = True) -> None:
    """Plot one line per metric over generations.

    Args:
        history: Mapping from metric name to its per-generation values.
            Values may be Python numbers or JAX scalars, and may be mixed;
            empty series are allowed.
        title: Plot title.
        save_path: If given, the figure is saved to this path.
        show: If True, display the figure; otherwise it is closed.
    """
    plt.figure(figsize=(10, 6))

    for metric, values in history.items():
        # Convert element-wise so that empty series and mixed Python/JAX
        # values are both handled.
        series = [float(v) if isinstance(v, jnp.ndarray) else v for v in values]
        plt.plot(series, label=metric)

    plt.title(title)
    plt.xlabel('Generation')
    plt.ylabel('Value')
    if history:
        # With no labelled lines, legend() only emits a warning.
        plt.legend()
    plt.grid(True)

    if save_path:
        plt.savefig(save_path)
    if show:
        plt.show()
    else:
        plt.close()