"""Visualization utilities for NEAT networks.""" import matplotlib.pyplot as plt import networkx as nx import numpy as np import jax.numpy as jnp from typing import List, Dict, Any from .network import Network def plot_network_structure(network: Network, title: str = "Network Structure", save_path: str = None, show: bool = True) -> None: """Plot network structure using networkx. Args: network: Network to visualize title: Plot title save_path: Path to save plot to show: Whether to display plot """ # Create graph G = nx.DiGraph() # Add nodes input_nodes = network.get_input_nodes() hidden_nodes = network.get_hidden_nodes() output_nodes = network.get_output_nodes() # Position nodes in layers pos = {} # Input layer for i, node in enumerate(input_nodes): G.add_node(node, layer='input') pos[node] = (0, (i - len(input_nodes)/2) / max(1, len(input_nodes)-1)) # Hidden layer for i, node in enumerate(hidden_nodes): G.add_node(node, layer='hidden') pos[node] = (1, (i - len(hidden_nodes)/2) / max(1, len(hidden_nodes)-1)) # Output layer for i, node in enumerate(output_nodes): G.add_node(node, layer='output') pos[node] = (2, (i - len(output_nodes)/2) / max(1, len(output_nodes)-1)) # Add edges with weights connections = network.get_connections() for src, dst, weight in connections: # Convert JAX array to NumPy float if isinstance(weight, jnp.ndarray): weight = float(weight) G.add_edge(src, dst, weight=weight) # Draw network plt.figure(figsize=(8, 6)) # Draw nodes node_colors = ['lightblue' if G.nodes[n]['layer'] == 'input' else 'lightgreen' if G.nodes[n]['layer'] == 'hidden' else 'salmon' for n in G.nodes()] nx.draw_networkx_nodes(G, pos, node_color=node_colors) # Draw edges with weights as colors edges = G.edges() weights = [G[u][v]['weight'] for u, v in edges] # Normalize weights for coloring max_weight = max(abs(min(weights)), abs(max(weights))) if max_weight > 0: norm_weights = [(w + max_weight)/(2*max_weight) for w in weights] else: norm_weights = [0.5] * len(weights) # Default to middle color if all weights are 0 nx.draw_networkx_edges(G, pos, edge_color=norm_weights, edge_cmap=plt.cm.RdYlBu, width=2) # Add labels labels = {n: str(n) for n in G.nodes()} nx.draw_networkx_labels(G, pos, labels) plt.title(title) plt.axis('off') if save_path: plt.savefig(save_path) if show: plt.show() else: plt.close() def plot_decision_boundary(network: Network, X: np.ndarray, y: np.ndarray, title: str = "Decision Boundary", save_path: str = None, show: bool = True) -> None: """Plot decision boundary for 2D classification problem. Args: network: Trained network X: Input data (n_samples, 2) y: Labels title: Plot title save_path: Path to save plot to show: Whether to display plot """ # Convert JAX arrays to NumPy if isinstance(X, jnp.ndarray): X = np.array(X) if isinstance(y, jnp.ndarray): y = np.array(y) # Create mesh grid h = 0.02 # Step size x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) # Make predictions on mesh mesh_points = np.c_[xx.ravel(), yy.ravel()] Z = network.predict(mesh_points) if isinstance(Z, jnp.ndarray): Z = np.array(Z) Z = Z.reshape(xx.shape) # Plot decision boundary plt.figure(figsize=(8, 6)) plt.contourf(xx, yy, Z, cmap=plt.cm.RdYlBu_r, alpha=0.3) # Plot training points plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.RdYlBu_r, alpha=0.6, edgecolors='gray') plt.xlim(xx.min(), xx.max()) plt.ylim(yy.min(), yy.max()) plt.title(title) if save_path: plt.savefig(save_path) if show: plt.show() else: plt.close() def plot_training_history(history: Dict[str, List[float]], title: str = "Training History", save_path: str = None, show: bool = True) -> None: """Plot training history metrics. Args: history: Dictionary of metrics title: Plot title save_path: Path to save plot to show: Whether to display plot """ plt.figure(figsize=(10, 6)) # Convert JAX arrays to NumPy if needed plot_history = {} for metric, values in history.items(): if isinstance(values[0], jnp.ndarray): plot_history[metric] = [float(v) for v in values] else: plot_history[metric] = values for metric, values in plot_history.items(): plt.plot(values, label=metric) plt.title(title) plt.xlabel('Generation') plt.ylabel('Value') plt.legend() plt.grid(True) if save_path: plt.savefig(save_path) if show: plt.show() else: plt.close()