from flax.core import FrozenDict import flax.linen as nn import jax import jax.numpy as jnp from functools import partial # --- Base functions --- def scale(state: jnp.ndarray) -> jnp.ndarray: return state / 255.0 class Torso(nn.Module): initialization_type: str @nn.compact def __call__(self, state): if self.initialization_type == "dqn": initializer = nn.initializers.variance_scaling(scale=1.0, mode="fan_avg", distribution="truncated_normal") elif self.initialization_type == "iqn": initializer = nn.initializers.variance_scaling( scale=1.0 / jnp.sqrt(3.0), mode="fan_in", distribution="uniform" ) x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4), kernel_init=initializer)(state) x = nn.relu(x) x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2), kernel_init=initializer)(x) x = nn.relu(x) x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1), kernel_init=initializer)(x) x = nn.relu(x) return x.flatten() class Head(nn.Module): n_actions: int initialization_type: str @nn.compact def __call__(self, x): if self.initialization_type == "dqn": initializer = nn.initializers.variance_scaling(scale=1.0, mode="fan_avg", distribution="truncated_normal") elif self.initialization_type == "iqn": initializer = nn.initializers.variance_scaling( scale=1.0 / jnp.sqrt(3.0), mode="fan_in", distribution="uniform" ) x = nn.Dense(features=512, kernel_init=initializer)(x) x = nn.relu(x) return nn.Dense(features=self.n_actions, kernel_init=initializer)(x) class QuantileEmbedding(nn.Module): n_features: int = 7744 quantile_embedding_dim: int = 64 @nn.compact def __call__(self, key, n_quantiles): initializer = nn.initializers.variance_scaling(scale=1.0 / jnp.sqrt(3.0), mode="fan_in", distribution="uniform") quantiles = jax.random.uniform(key, shape=(n_quantiles, 1)) arange = jnp.arange(1, self.quantile_embedding_dim + 1).reshape((1, self.quantile_embedding_dim)) quantile_embedding = nn.Dense(features=self.n_features, kernel_init=initializer)( jnp.cos(jnp.pi * quantiles @ arange) ) # output (n_quantiles, n_features) | (n_quantiles) return (nn.relu(quantile_embedding), jnp.squeeze(quantiles, axis=1)) # --- i-DQN networks --- class AtariSharediDQNNet: def __init__(self, n_actions: int) -> None: self.n_heads = 5 self.n_actions = n_actions self.torso = Torso("dqn") self.head = Head(self.n_actions, "dqn") def apply(self, params: FrozenDict, idx_head: int, state: jnp.ndarray) -> jnp.ndarray: feature = self.torso.apply( params[f"torso_params_{min(idx_head, 1)}"], state, ) return self.head.apply(params[f"head_params_{idx_head}"], feature) class AtariiDQN: def __init__(self, n_actions: int, idx_head: int) -> None: self.network = AtariSharediDQNNet(n_actions) self.idx_head = idx_head @partial(jax.jit, static_argnames="self") def best_action(self, params: FrozenDict, state: jnp.ndarray, key: jax.random.PRNGKeyArray) -> jnp.int8: return jnp.argmax(self.network.apply(params, self.idx_head, scale(state))).astype(jnp.int8) # --- i-IQN networks --- class AtariSharediIQNNet: def __init__(self, n_actions: int) -> None: self.n_heads = 4 self.n_actions = n_actions self.torso = Torso("iqn") self.quantile_embedding = QuantileEmbedding() self.head = Head(self.n_actions, "iqn") def apply( self, params: FrozenDict, idx_head: int, state: jnp.ndarray, key: jax.random.PRNGKey, n_quantiles: int ) -> jnp.ndarray: # output (n_features) state_feature = self.torso.apply( jax.tree_util.tree_map( lambda param: param[jax.lax.cond(idx_head >= 1, lambda: 1, lambda: 0)], params["torso_params"] ), state, ) # output (n_quantiles, n_features) quantiles_feature, _ = self.quantile_embedding.apply( jax.tree_util.tree_map( lambda param: param[jax.lax.cond(idx_head >= 1, lambda: 1, lambda: 0)], params["quantiles_params"] ), key, n_quantiles, ) # mapping over the quantiles | output (n_quantiles, n_features) feature = jax.vmap( lambda quantile_feature_, state_feature_: quantile_feature_ * state_feature_, in_axes=(0, None) )(quantiles_feature, state_feature) return self.head.apply( jax.tree_util.tree_map(lambda param: param[idx_head], params["head_params"]), feature ) # output (n_quantiles, n_actions) class AtariiIQN: def __init__(self, n_actions: int, idx_head: int) -> None: self.network = AtariSharediIQNNet(n_actions) self.idx_head = idx_head self.n_quantiles_policy = 32 @partial(jax.jit, static_argnames="self") def best_action(self, params: FrozenDict, state: jnp.ndarray, key: jax.random.PRNGKeyArray) -> jnp.int8: # output (n_quantiles, n_actions) q_quantiles = self.network.apply(params, self.idx_head, scale(state), key, self.n_quantiles_policy) q_values = jnp.mean(q_quantiles, axis=0) return jnp.argmax(q_values).astype(jnp.int8)