neat / neat\datasets.py
eyad-silx's picture
Upload neat\datasets.py with huggingface_hub
80f8293 verified
raw
history blame
7.87 kB
"""Dataset generation functions for testing BackpropNEAT."""
import numpy as np
import jax.numpy as jnp
def generate_xor_data(n_samples: int = 200, complexity: float = 1.0) -> tuple:
    """Generate a multi-cluster XOR dataset with rotations and noise.

    Each quadrant holds three small disc-shaped clusters that sit further from
    the origin as the cluster index grows. Opposite quadrants share a label:
    bottom-left / top-right -> -1, top-left / bottom-right -> +1 (before the
    global rotation is applied).

    Args:
        n_samples: Number of samples per quadrant. They are split as evenly as
            possible across the clusters (the remainder goes to the first
            clusters), so exactly ``n_samples`` points land in each quadrant.
        complexity: Scales the rotation angles: 30 degrees per cluster index
            and a 45 degree global rotation at ``complexity=1``. The Gaussian
            jitter (std 0.05) is fixed and does not depend on it.

    Returns:
        Tuple ``(X, y)`` of JAX arrays, shuffled together: ``X`` has shape
        ``(4 * n_samples, 2)`` and ``y`` has shape ``(4 * n_samples,)`` with
        values in {-1, +1}.

    Raises:
        ValueError: If ``n_samples`` is negative.
    """
    if n_samples < 0:
        raise ValueError(f"n_samples must be non-negative, got {n_samples}")

    n_clusters = 3
    radius = 0.2
    noise_std = 0.05
    # Distribute the remainder instead of dropping it: plain
    # n_samples // n_clusters would silently lose up to two points per quadrant.
    base_count, remainder = divmod(n_samples, n_clusters)

    points = []
    labels = []
    for cluster in range(n_clusters):
        count = base_count + (1 if cluster < remainder else 0)
        # 30 degree rotation per cluster index (scaled by complexity).
        # NOTE(review): the disc sample below uses a uniform angle, so rotating
        # it does not change its distribution; kept for behavioral parity.
        rotation = complexity * cluster * np.pi / 6
        cos_r, sin_r = np.cos(rotation), np.sin(rotation)
        # Clusters move outward along the diagonals, leaving gaps between them.
        offset = 0.7 + 0.3 * cluster
        centers = [
            # (x, y, label)
            (-offset, -offset, -1),  # Bottom-left
            (offset, offset, -1),  # Top-right
            (-offset, offset, 1),  # Top-left
            (offset, -offset, 1),  # Bottom-right
        ]
        for cx, cy, label in centers:
            # Sample a disc around the center. A uniform r (not sqrt(U)) makes
            # points denser toward the middle of the disc.
            angle = np.random.uniform(0, 2 * np.pi, count)
            r = np.random.uniform(0, radius, count)
            dx = r * np.cos(angle)
            dy = r * np.sin(angle)
            dx_rot = dx * cos_r - dy * sin_r
            dy_rot = dx * sin_r + dy * cos_r
            # Shift to the cluster center and add isotropic Gaussian jitter.
            px = cx + dx_rot + np.random.normal(0, noise_std, count)
            py = cy + dy_rot + np.random.normal(0, noise_std, count)
            points.append(np.column_stack([px, py]))
            labels.extend([label] * count)

    X = np.vstack(points)
    y = np.array(labels, dtype=np.float32)

    # Global rotation by complexity * 45 degrees. Right-multiplying row
    # vectors by this matrix rotates them clockwise.
    global_angle = complexity * np.pi / 4
    rotation_matrix = np.array([
        [np.cos(global_angle), -np.sin(global_angle)],
        [np.sin(global_angle), np.cos(global_angle)],
    ])
    X = X @ rotation_matrix

    # Shuffle features and labels with the same permutation.
    perm = np.random.permutation(len(X))
    return jnp.array(X[perm]), jnp.array(y[perm])
def generate_circle_data(n_samples: int = 1000, noise: float = 0.1) -> tuple:
    """Build a two-ring (concentric circles) classification dataset.

    Both rings reuse one shared draw of random angles; they differ only in
    radius (0.5 for the inner ring, 1.5 for the outer ring). Each radius gets
    its own independent Gaussian jitter.

    Args:
        n_samples: How many points to draw for each ring (i.e. per class).
        noise: Standard deviation of the Gaussian jitter added to each radius.

    Returns:
        Tuple ``(X, y)`` of NumPy arrays, shuffled together: ``X`` has shape
        ``(2 * n_samples, 2)``; ``y`` is -1.0 for inner-ring points and 1.0
        for outer-ring points.
    """
    angles = np.random.uniform(0, 2*np.pi, n_samples)
    # Unit direction vectors shared by both rings.
    directions = np.column_stack([np.cos(angles), np.sin(angles)])

    inner_radii = 0.5 + np.random.normal(0, noise, n_samples)
    outer_radii = 1.5 + np.random.normal(0, noise, n_samples)

    features = np.concatenate([
        directions * inner_radii[:, None],  # class -1
        directions * outer_radii[:, None],  # class +1
    ])
    targets = np.concatenate([
        np.full(n_samples, -1.0),
        np.full(n_samples, 1.0),
    ])

    order = np.random.permutation(features.shape[0])
    return features[order], targets[order]
def generate_spiral_dataset(n_points=1000, noise=0.1):
    """Generate a two-arm spiral dataset with hand-crafted polar features.

    The spiral parameter is ``theta = sqrt(U(0, 1)) * 4*pi`` (two revolutions).
    Arm 0 uses angle ``theta``; arm 1 is the same curve rotated by pi. The
    radius grows linearly, ``r = theta / (4*pi)``. Both radius and angle get
    Gaussian noise scaled by ``noise * (1 - exp(-2 * r))``, so points near the
    center are almost noise-free. The angle noise uses 10% of that scale.

    Note: the feature vector contains the raw Cartesian coordinates, so it is
    not rotation invariant even though it also carries polar features.

    Args:
        n_points: Number of spiral-parameter samples. Each sample yields one
            point per arm, so ``2 * n_points`` rows are returned.
        noise: Base noise scale (must be non-negative).

    Returns:
        Tuple ``(data, labels)`` of NumPy arrays. ``data`` has shape
        ``(2 * n_points, 10)`` and its rows alternate arm 0 / arm 1. The
        columns are: x/2, y/2, r/2, sin(phase), cos(phase),
        tanh(2 * curvature), angular_pos/2, tanh(tightness),
        tanh(10 * dr/dtheta) (a constant), dist_to_other/4. ``labels`` is
        -1 for arm 0 and +1 for arm 1.
    """
    eps = 1e-8
    n_arms = 2

    # sqrt(U) gives theta a density that grows linearly with theta. This puts
    # MORE samples at large theta, which roughly evens out density per unit
    # of arc length, since the arc gets longer as the radius grows.
    theta = np.sqrt(np.random.uniform(0, 1, n_points)) * 4 * np.pi
    r_base = theta / (4 * np.pi)
    # Noise grows with radius.
    noise_scale = noise * (1 - np.exp(-2 * r_base))

    # All arrays below have shape (n_points, n_arms). Column k is arm k, so
    # flattening in row-major order yields rows ordered (point, arm).
    scale = np.repeat(noise_scale[:, None], n_arms, axis=1)
    angle = theta[:, None] + np.pi * np.arange(n_arms)
    r = r_base[:, None] + np.random.normal(0.0, scale)
    angle = angle + np.random.normal(0.0, scale * 0.1)  # Less noise in angle

    x = r * np.cos(angle)
    y = r * np.sin(angle)
    r_point = np.hypot(x, y)

    # Continuous (unwrapped) polar angle of each point. Radial noise can make
    # r negative, which puts the point on the opposite ray (angle + pi).
    # Recovering this from arctan2 plus a revolution count is error-prone,
    # because arctan2 returns negative angles in the lower half-plane.
    theta_unwrapped = np.where(r < 0, angle + np.pi, angle)

    # 1. Local curvature proxy: inverse radius.
    curvature = 1 / (r_point + eps)
    # 2. Phase within the current revolution, in [0, 1).
    phase = theta_unwrapped % (2 * np.pi) / (2 * np.pi)
    # 3. Radial velocity dr/dtheta of the noise-free spiral (a constant).
    dr_dtheta = 1 / (4 * np.pi)
    # 4. Angular position across revolutions.
    angular_pos = theta_unwrapped / (4 * np.pi)
    # 5. Tightness: radius per unit of unwrapped angle.
    tightness = r_point / (theta_unwrapped + eps)
    # 6. The other arm's point at the same r and angle + pi is (-x, -y),
    #    so the distance to it is exactly 2 * |r|.
    dist_to_other = 2 * r_point
    # 7. Phase on the unit circle.
    sin_phase = np.sin(phase * 2 * np.pi)
    cos_phase = np.cos(phase * 2 * np.pi)

    features = np.stack([
        x / 2.0,  # Normalize coordinates
        y / 2.0,
        r_point / 2.0,  # Normalize radius
        sin_phase,  # Already in [-1, 1]
        cos_phase,  # Already in [-1, 1]
        np.tanh(curvature * 2),  # Squash curvature
        angular_pos / 2.0,  # Normalize angular position
        np.tanh(tightness),  # Squash tightness
        np.full_like(x, np.tanh(dr_dtheta * 10)),  # Constant radial velocity
        dist_to_other / 4.0,  # Normalize distance to other arm
    ], axis=-1)

    data = features.reshape(-1, features.shape[-1])
    labels = np.tile(np.arange(n_arms) * 2 - 1, n_points)  # [-1, 1, -1, 1, ...]
    return data, labels
def generate_checkerboard_data(n_samples: int = 200) -> tuple:
    """Generate a 4x4 checkerboard classification dataset on [-2, 2)^2.

    Points are drawn uniformly over the square and labeled by the parity of
    their unit cell. The two classes are therefore balanced only in
    expectation: ``n_samples`` is not an exact per-class count.

    Args:
        n_samples: Half the total number of points; ``2 * n_samples`` points
            are generated.

    Returns:
        Tuple ``(X, y)`` of JAX arrays: ``X`` has shape ``(2 * n_samples, 2)``
        and ``y`` has shape ``(2 * n_samples,)``. ``y`` is 1.0 where
        ``floor(x1) + floor(x2)`` is even and 0.0 otherwise. Note that these
        are 0/1 labels, unlike the -1/+1 labels of the other generators here.
    """
    X = np.random.uniform(-2, 2, (n_samples * 2, 2))
    # Integer cell coordinates. NumPy's % on ints follows Python's sign rules,
    # so negative cell sums get the same parity as int(...) % 2 would give.
    cell_sum = np.floor(X).astype(np.int64).sum(axis=1)
    y = (cell_sum % 2 == 0).astype(np.float64)
    return jnp.array(X), jnp.array(y)
# Public API of this module: the dataset generator functions.
__all__ = [
    'generate_xor_data',
    'generate_circle_data',
    'generate_spiral_dataset',
    'generate_checkerboard_data',
]