File size: 5,492 Bytes
0232428
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
"""Visualization utilities for NEAT networks."""

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import jax.numpy as jnp
from typing import List, Dict, Any
from .network import Network

def plot_network_structure(network: Network, title: str = "Network Structure",
                         save_path: str = None, show: bool = True) -> None:
    """Plot network structure using networkx.

    Nodes are laid out in three columns: inputs at x=0, hidden nodes at
    x=1 and outputs at x=2. Each column is vertically centred on y=0.
    Edges are coloured by weight on the diverging ``RdYlBu`` colormap,
    scaled symmetrically around zero. Negative weights map towards the red
    end, positive weights towards the blue end, and a weight of 0 maps to
    the middle of the colormap.

    Args:
        network: Network to visualize. Must provide ``get_input_nodes()``,
            ``get_hidden_nodes()``, ``get_output_nodes()`` and
            ``get_connections()``. The last must yield
            ``(src, dst, weight)`` triples whose weights are convertible
            with ``float()``.
        title: Plot title
        save_path: Path to save plot to; nothing is saved when falsy.
        show: Whether to display plot; when False the figure is closed.
    """
    # Copy into fresh lists so that appending below never mutates whatever
    # the network returned.
    input_nodes = list(network.get_input_nodes())
    hidden_nodes = list(network.get_hidden_nodes())
    output_nodes = list(network.get_output_nodes())

    # float() handles JAX/NumPy scalars and plain Python numbers alike.
    connections = [(src, dst, float(weight))
                   for src, dst, weight in network.get_connections()]

    # A connection may reference a node that none of the three getters
    # returned. Such a node would have no layer or position and make the
    # drawing calls fail, so show it in the hidden column.
    known = set(input_nodes) | set(hidden_nodes) | set(output_nodes)
    for src, dst, _ in connections:
        for node in (src, dst):
            if node not in known:
                known.add(node)
                hidden_nodes.append(node)

    G = nx.DiGraph()
    pos = {}
    for layer, x, nodes in (('input', 0, input_nodes),
                            ('hidden', 1, hidden_nodes),
                            ('output', 2, output_nodes)):
        # Centre the column on y=0 and keep it within [-0.5, 0.5].
        centre = (len(nodes) - 1) / 2
        span = max(1, len(nodes) - 1)
        for i, node in enumerate(nodes):
            G.add_node(node, layer=layer)
            pos[node] = (x, (i - centre) / span)

    G.add_edges_from((src, dst, {'weight': w}) for src, dst, w in connections)

    plt.figure(figsize=(8, 6))

    layer_colors = {'input': 'lightblue', 'hidden': 'lightgreen',
                    'output': 'salmon'}
    node_colors = [layer_colors[G.nodes[n]['layer']] for n in G.nodes()]
    nx.draw_networkx_nodes(G, pos, node_color=node_colors)

    edges = list(G.edges())
    if edges:  # max() below would raise on a network without connections
        weights = [G[u][v]['weight'] for u, v in edges]
        max_weight = max(abs(w) for w in weights)
        if max_weight > 0:
            norm_weights = [(w + max_weight) / (2 * max_weight) for w in weights]
        else:
            # All weights are 0: use the middle colour.
            norm_weights = [0.5] * len(weights)
        # Pin the colour scale to [0, 1]. Otherwise networkx rescales the
        # colours to the min/max of norm_weights and the symmetric mapping
        # around zero is lost.
        nx.draw_networkx_edges(G, pos, edgelist=edges, edge_color=norm_weights,
                               edge_cmap=plt.cm.RdYlBu, edge_vmin=0.0,
                               edge_vmax=1.0, width=2)

    labels = {n: str(n) for n in G.nodes()}
    nx.draw_networkx_labels(G, pos, labels)

    plt.title(title)
    plt.axis('off')

    if save_path:
        plt.savefig(save_path)

    if show:
        plt.show()
    else:
        plt.close()

def plot_decision_boundary(network: Network, X: np.ndarray, y: np.ndarray,
                         title: str = "Decision Boundary",
                         save_path: str = None, show: bool = True,
                         h: float = 0.02, padding: float = 1.0) -> None:
    """Plot decision boundary for 2D classification problem.

    The network is evaluated on a regular grid that covers the data extent
    plus ``padding`` on every side. Its predictions are drawn as filled
    contours, and the samples are scattered on top, coloured by label.

    Args:
        network: Trained network. ``network.predict`` is called once with an
            ``(m, 2)`` NumPy array of grid points. Its result must reshape
            to the grid shape, i.e. one value per grid point.
        X: Input data (n_samples, 2); a NumPy array, JAX array or nested
            sequence.
        y: Labels, one per sample.
        title: Plot title
        save_path: Path to save plot to; nothing is saved when falsy.
        show: Whether to display plot; when False the figure is closed.
        h: Grid step size. The grid holds roughly
            ``(range_x + 2*padding) * (range_y + 2*padding) / h**2`` points,
            so increase ``h`` for data with a large range.
        padding: Margin added around the data extent on each axis.

    Raises:
        ValueError: If ``h`` is not positive.
    """
    if h <= 0:
        raise ValueError(f"h must be positive, got {h}")

    # np.asarray converts NumPy, JAX and plain sequences uniformly.
    X = np.asarray(X)
    y = np.asarray(y).ravel()

    # Build the evaluation grid.
    x_min, x_max = X[:, 0].min() - padding, X[:, 0].max() + padding
    y_min, y_max = X[:, 1].min() - padding, X[:, 1].max() + padding
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))

    # Predict on every grid point at once, then restore the grid shape.
    mesh_points = np.c_[xx.ravel(), yy.ravel()]
    Z = np.asarray(network.predict(mesh_points)).reshape(xx.shape)

    plt.figure(figsize=(8, 6))
    plt.contourf(xx, yy, Z, cmap=plt.cm.RdYlBu_r, alpha=0.3)

    plt.scatter(X[:, 0], X[:, 1], c=y,
                cmap=plt.cm.RdYlBu_r,
                alpha=0.6,
                edgecolors='gray')
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())

    plt.title(title)

    if save_path:
        plt.savefig(save_path)

    if show:
        plt.show()
    else:
        plt.close()

def plot_training_history(history: Dict[str, List[float]], 
                         title: str = "Training History",
                         save_path: str = None, show: bool = True) -> None:
    """Plot training history metrics.

    Each entry of ``history`` is drawn as one line against its index, which
    is labelled "Generation" on the x-axis. The dict key is used as the
    legend label.

    Args:
        history: Mapping from metric name to its per-generation values.
            Values may be Python numbers or JAX scalars, mixed freely, and a
            metric's list may be empty.
        title: Plot title
        save_path: Path to save plot to; nothing is saved when falsy.
        show: Whether to display plot; when False the figure is closed.
    """
    plt.figure(figsize=(10, 6))

    for metric, values in history.items():
        # Convert JAX scalars element by element. Checking only values[0]
        # would raise IndexError on an empty list and would miss JAX values
        # that follow a plain number.
        series = [float(v) if isinstance(v, jnp.ndarray) else v for v in values]
        plt.plot(series, label=metric)

    plt.title(title)
    plt.xlabel('Generation')
    plt.ylabel('Value')
    if history:
        # With nothing plotted, matplotlib warns that no labelled artists exist.
        plt.legend()
    plt.grid(True)

    if save_path:
        plt.savefig(save_path)

    if show:
        plt.show()
    else:
        plt.close()