"""Visualization utilities for NEAT networks and training progress.""" import os import numpy as np import matplotlib.pyplot as plt import networkx as nx from typing import List, Dict, Any import imageio from IPython.display import HTML from neat.network import Network from neat.genome import Genome def draw_network(network: Network, save_path: str = None) -> None: """Draw a neural network visualization using networkx and matplotlib. Args: network: The network to visualize save_path: Optional path to save the visualization """ # Create directed graph G = nx.DiGraph() # Track node types and positions node_types = {} node_positions = {} # Collect all unique nodes from connections all_nodes = set() for conn in network.connection_genes: if conn.enabled: all_nodes.add(conn.source) all_nodes.add(conn.target) # Calculate layout parameters layer_spacing = 2.0 # Add input nodes (leftmost layer) input_nodes = set(range(network.input_size)) input_y = np.linspace(-1, 1, len(input_nodes)) for i, node in enumerate(sorted(input_nodes)): node_id = str(node) node_types[node_id] = 'input' node_positions[node_id] = np.array([0, input_y[i]]) G.add_node(node_id) all_nodes.discard(node) # Remove from remaining nodes # Add output nodes (rightmost layer) output_start = len(network.node_genes) - network.output_size output_nodes = set(range(output_start, len(network.node_genes))) output_y = np.linspace(-1, 1, len(output_nodes)) for i, node in enumerate(sorted(output_nodes)): node_id = str(node) node_types[node_id] = 'output' node_positions[node_id] = np.array([layer_spacing, output_y[i]]) G.add_node(node_id) all_nodes.discard(node) # Add hidden nodes (middle layer) hidden_nodes = all_nodes # Remaining nodes are hidden if hidden_nodes: hidden_y = np.linspace(-1, 1, len(hidden_nodes)) for i, node in enumerate(sorted(hidden_nodes)): node_id = str(node) node_types[node_id] = 'hidden' node_positions[node_id] = np.array([layer_spacing/2, hidden_y[i]]) G.add_node(node_id) # Add connections for conn in network.connection_genes: if conn.enabled: G.add_edge(str(conn.source), str(conn.target), weight=conn.weight) # Draw the network plt.figure(figsize=(8, 6)) # Draw nodes for node, (x, y) in node_positions.items(): node_type = node_types[node] if node_type == 'input': color = 'lightblue' elif node_type == 'hidden': color = 'gray' else: # output color = 'lightgreen' plt.scatter(x, y, c=color, s=500, zorder=2) plt.text(x, y, node, horizontalalignment='center', verticalalignment='center') # Draw edges edge_weights = [G[u][v]['weight'] for u, v in G.edges()] pos = node_positions nx.draw_networkx_edges(G, pos, edge_color='gray', width=1, alpha=0.5, arrows=True, arrowsize=10, edge_cmap=plt.cm.RdYlBu, edge_vmin=-1, edge_vmax=1, connectionstyle="arc3,rad=0.2") plt.title("Neural Network Architecture") plt.axis('equal') plt.axis('off') if save_path: plt.savefig(save_path, bbox_inches='tight', dpi=300) plt.close() else: plt.show() def plot_training_history(history: Dict[str, List[float]], save_path: str = None) -> None: """Plot training metrics over generations. Args: history: Dictionary containing lists of metrics per generation save_path: Optional path to save the plot """ plt.figure(figsize=(12, 8)) # Plot fitness metrics if 'best_fitness' in history: plt.plot(history['best_fitness'], label='Best Fitness', color='green') if 'avg_fitness' in history: plt.plot(history['avg_fitness'], label='Average Fitness', color='blue') # Plot species count if available if 'species_count' in history: ax2 = plt.twinx() ax2.plot(history['species_count'], label='Species Count', color='red', linestyle='--') ax2.set_ylabel('Number of Species') plt.xlabel('Generation') plt.ylabel('Fitness') plt.title('Training Progress') plt.legend() if save_path: plt.savefig(save_path, bbox_inches='tight') plt.close() else: plt.show() def create_gameplay_gif(frames: List[np.ndarray], output_path: str, fps: int = 30) -> None: """Create a GIF from gameplay frames. Args: frames: List of frames as numpy arrays output_path: Path to save the GIF fps: Frames per second for the GIF """ # Ensure output directory exists os.makedirs(os.path.dirname(output_path), exist_ok=True) # Save frames as GIF imageio.mimsave(output_path, frames, fps=fps) def plot_species_complexity(species_stats: List[Dict[str, Any]], save_path: str = None) -> None: """Plot the complexity of species over generations. Args: species_stats: List of dictionaries containing species statistics per generation save_path: Optional path to save the plot """ plt.figure(figsize=(12, 8)) generations = range(len(species_stats)) avg_nodes = [stats['avg_nodes'] for stats in species_stats] avg_connections = [stats['avg_connections'] for stats in species_stats] plt.plot(generations, avg_nodes, label='Average Nodes', color='blue') plt.plot(generations, avg_connections, label='Average Connections', color='green') plt.xlabel('Generation') plt.ylabel('Count') plt.title('Network Complexity Over Time') plt.legend() if save_path: plt.savefig(save_path, bbox_inches='tight') plt.close() else: plt.show() def plot_activation_distribution(genomes: List[Genome], save_path: str = None) -> None: """Plot the distribution of activation functions across the population. Args: genomes: List of genomes to analyze save_path: Optional path to save the plot """ activation_counts = {} # Count activation functions for genome in genomes: for node in genome.nodes.values(): activation_name = node.activation.__name__ if hasattr(node.activation, '__name__') else str(node.activation) activation_counts[activation_name] = activation_counts.get(activation_name, 0) + 1 # Create bar plot plt.figure(figsize=(10, 6)) plt.bar(activation_counts.keys(), activation_counts.values()) plt.xticks(rotation=45) plt.xlabel('Activation Function') plt.ylabel('Count') plt.title('Distribution of Activation Functions') if save_path: plt.savefig(save_path, bbox_inches='tight') plt.close() else: plt.show()