import jax.numpy as jnp


class NEAlgorithm:
    """Base class for neuroevolution algorithms.

    Subclasses implement the ask/tell protocol: ``ask`` returns the current
    population parameters and ``tell`` updates the population from the
    fitness values of those parameters.

    Attributes:
        gen: Generation counter, initialised to 0. NOTE(review): this base
            class never increments it; presumably subclasses do so in
            ``tell`` — confirm.
        pop: Population storage, initialised to a fresh empty list per
            instance. Its contents are defined by subclasses.
    """

    def __init__(self) -> None:
        self.gen: int = 0
        # A new list per instance, so instances never share population state.
        self.pop: list = []

    def ask(self) -> jnp.ndarray:
        """Return current population parameters.

        Raises:
            NotImplementedError: Always, in this base class; subclasses must
                override it.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement ask()"
        )

    def tell(self, fitness_array: jnp.ndarray) -> None:
        """Update population based on fitness values.

        Args:
            fitness_array: Fitness values for the population returned by the
                preceding ``ask`` call (presumably one value per member —
                confirm against subclasses).

        Raises:
            NotImplementedError: Always, in this base class; subclasses must
                override it.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement tell()"
        )