import jax.numpy as jnp
class NEAlgorithm:
    """Base class for neuroevolution algorithms.

    Defines the ask/tell interface: ``ask`` hands out the current
    population parameters and ``tell`` feeds fitness values back.
    Both raise ``NotImplementedError`` here, so concrete algorithms
    are expected to override them.
    """

    def __init__(self) -> None:
        """Initialise the bookkeeping attributes shared by all algorithms."""
        # Presumably the generation counter; this base class only sets it
        # to zero and never advances it -- confirm against subclasses.
        self.gen = 0
        # Starts as an empty list; what subclasses store here (and in what
        # container type) is not visible from this class.
        self.pop = []

    def ask(self) -> jnp.ndarray:
        """Return current population parameters.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def tell(self, fitness_array: jnp.ndarray) -> None:
        """Update population based on fitness values.

        Args:
            fitness_array: Fitness values used for the update; the expected
                shape and ordering are defined by the subclass.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError