|
from flax.core import FrozenDict |
|
import flax.linen as nn |
|
import jax |
|
import jax.numpy as jnp |
|
from functools import partial |
|
|
|
|
|
|
|
|
|
|
|
def scale(state: jnp.ndarray) -> jnp.ndarray:
    """Return ``state`` divided by 255.0.

    NOTE(review): presumably used to bring 8-bit pixel observations into
    [0, 1]; the input range is not checked here -- confirm against callers.
    """
    max_pixel_value = 255.0
    rescaled = state / max_pixel_value
    return rescaled
|
|
|
|
|
class Torso(nn.Module):
    """Convolutional feature extractor: three Conv + ReLU stages, then flatten.

    The kernel initializer of every convolution is chosen by
    ``initialization_type``.
    """

    # Selects the kernel initializer; must be "dqn" or "iqn".
    initialization_type: str

    @nn.compact
    def __call__(self, state):
        """Apply the convolutional stack to ``state`` and return a 1-D feature.

        Args:
            state: input array fed to the first convolution.
                NOTE(review): the final ``flatten()`` collapses every axis, so
                this presumably expects a single unbatched observation with
                batching done by an outer ``vmap`` -- confirm.

        Returns:
            The output of the last ReLU, flattened to one dimension.

        Raises:
            ValueError: if ``initialization_type`` is neither "dqn" nor "iqn".
        """
        if self.initialization_type == "dqn":
            initializer = nn.initializers.variance_scaling(scale=1.0, mode="fan_avg", distribution="truncated_normal")
        elif self.initialization_type == "iqn":
            initializer = nn.initializers.variance_scaling(
                scale=1.0 / jnp.sqrt(3.0), mode="fan_in", distribution="uniform"
            )
        else:
            # Without this branch `initializer` would be unbound and the first
            # Conv below would fail with an opaque UnboundLocalError.
            raise ValueError(
                f"Unknown initialization_type {self.initialization_type!r}; expected 'dqn' or 'iqn'."
            )

        x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4), kernel_init=initializer)(state)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2), kernel_init=initializer)(x)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1), kernel_init=initializer)(x)
        x = nn.relu(x)

        return x.flatten()
|
|
|
|
|
class Head(nn.Module):
    """Two-layer dense head: Dense(512) + ReLU, then Dense(n_actions).

    The kernel initializer of both layers is chosen by
    ``initialization_type``.
    """

    # Number of output units of the final dense layer.
    n_actions: int
    # Selects the kernel initializer; must be "dqn" or "iqn".
    initialization_type: str

    @nn.compact
    def __call__(self, x):
        """Map the feature ``x`` to ``n_actions`` outputs.

        Args:
            x: input feature array; the dense layers act on its last axis.

        Returns:
            The output of the final dense layer (no activation applied).

        Raises:
            ValueError: if ``initialization_type`` is neither "dqn" nor "iqn".
        """
        if self.initialization_type == "dqn":
            initializer = nn.initializers.variance_scaling(scale=1.0, mode="fan_avg", distribution="truncated_normal")
        elif self.initialization_type == "iqn":
            initializer = nn.initializers.variance_scaling(
                scale=1.0 / jnp.sqrt(3.0), mode="fan_in", distribution="uniform"
            )
        else:
            # Without this branch `initializer` would be unbound and the first
            # Dense below would fail with an opaque UnboundLocalError.
            raise ValueError(
                f"Unknown initialization_type {self.initialization_type!r}; expected 'dqn' or 'iqn'."
            )

        x = nn.Dense(features=512, kernel_init=initializer)(x)
        x = nn.relu(x)

        return nn.Dense(features=self.n_actions, kernel_init=initializer)(x)
|
|
|
|
|
class QuantileEmbedding(nn.Module):
    """Sample quantile levels and embed them with a cosine basis + dense layer."""

    # Output width of the dense projection.
    # NOTE(review): 7744 presumably matches the flattened Torso output -- confirm.
    n_features: int = 7744
    # Number of cosine basis functions evaluated per sampled quantile.
    quantile_embedding_dim: int = 64

    @nn.compact
    def __call__(self, key, n_quantiles):
        """Draw ``n_quantiles`` uniform samples and embed each of them.

        Args:
            key: PRNG key passed to ``jax.random.uniform``.
            n_quantiles: number of quantile levels to sample.

        Returns:
            A pair ``(embedding, quantiles)``: the ReLU of the dense projection
            with shape ``(n_quantiles, n_features)``, and the sampled quantile
            levels with shape ``(n_quantiles,)``.
        """
        kernel_init = nn.initializers.variance_scaling(
            scale=1.0 / jnp.sqrt(3.0), mode="fan_in", distribution="uniform"
        )

        # Column vector of sampled levels, one row per quantile.
        taus = jax.random.uniform(key, shape=(n_quantiles, 1))
        # Row vector 1, 2, ..., quantile_embedding_dim.
        frequencies = jnp.arange(1, self.quantile_embedding_dim + 1).reshape((1, self.quantile_embedding_dim))

        # (n_quantiles, 1) @ (1, dim) gives one cosine feature row per quantile.
        cosine_basis = jnp.cos((jnp.pi * taus) @ frequencies)
        projected = nn.Dense(features=self.n_features, kernel_init=kernel_init)(cosine_basis)

        return (nn.relu(projected), jnp.squeeze(taus, axis=1))
|
|
|
|
|
|
|
|
|
|
|
class AtariSharediDQNNet:
    """DQN-style torso and head modules applied with per-head parameter sets."""

    def __init__(self, n_actions: int) -> None:
        """Build the "dqn"-initialised torso and head for ``n_actions`` outputs."""
        self.n_actions = n_actions
        self.n_heads = 5
        self.torso = Torso(initialization_type="dqn")
        self.head = Head(n_actions=self.n_actions, initialization_type="dqn")

    def apply(self, params: FrozenDict, idx_head: int, state: jnp.ndarray) -> jnp.ndarray:
        """Run the torso then the head selected by ``idx_head``.

        Args:
            params: mapping holding ``torso_params_{0,1}`` and
                ``head_params_{idx_head}`` entries.
            idx_head: index of the head to evaluate. It is used with the
                builtin ``min`` and inside f-strings, so it must be a plain
                Python integer rather than a traced value.
            state: input passed to the torso.

        Returns:
            The output of the head applied to the torso feature.
        """
        # Head 0 reads torso 0; every other head reads torso 1.
        torso_key = f"torso_params_{min(idx_head, 1)}"
        head_key = f"head_params_{idx_head}"

        feature = self.torso.apply(params[torso_key], state)
        return self.head.apply(params[head_key], feature)
|
|
|
|
|
class AtariiDQN:
    """Greedy-action wrapper around one head of ``AtariSharediDQNNet``."""

    def __init__(self, n_actions: int, idx_head: int) -> None:
        """Create the shared network and remember which head to query.

        Args:
            n_actions: number of outputs of the network head.
            idx_head: index of the head used by ``best_action``.
        """
        self.network = AtariSharediDQNNet(n_actions)
        self.idx_head = idx_head

    # `self` is a static argument, so each instance gets its own compilation
    # and `self.idx_head` is an ordinary Python value while tracing.
    @partial(jax.jit, static_argnames="self")
    def best_action(self, params: FrozenDict, state: jnp.ndarray, key: jnp.ndarray) -> jnp.int8:
        """Return the index of the largest network output for ``state``.

        Args:
            params: parameters forwarded to ``self.network.apply``.
            state: raw observation; it is passed through ``scale`` first.
            key: PRNG key; unused in this method.
                NOTE(review): presumably kept so the signature matches the IQN
                wrapper -- confirm. It is annotated as ``jnp.ndarray`` because
                the ``jax.random.PRNGKeyArray`` alias does not exist in recent
                JAX releases, which would break import of this module.

        Returns:
            The argmax over the network output, cast to ``int8``.
        """
        return jnp.argmax(self.network.apply(params, self.idx_head, scale(state))).astype(jnp.int8)
|
|
|
|
|
|
|
|
|
|
|
class AtariSharediIQNNet:
    """IQN-style torso, quantile embedding and head applied to stacked params."""

    def __init__(self, n_actions: int) -> None:
        """Build the "iqn"-initialised torso, quantile embedding and head."""
        self.n_actions = n_actions
        self.n_heads = 4
        self.torso = Torso(initialization_type="iqn")
        self.quantile_embedding = QuantileEmbedding()
        self.head = Head(n_actions=self.n_actions, initialization_type="iqn")

    def apply(
        self, params: FrozenDict, idx_head: int, state: jnp.ndarray, key: jax.random.PRNGKey, n_quantiles: int
    ) -> jnp.ndarray:
        """Evaluate head ``idx_head`` on ``state`` for freshly sampled quantiles.

        Every leaf of ``params["torso_params"]``, ``params["quantiles_params"]``
        and ``params["head_params"]`` is indexed along its leading axis to
        select one parameter set.

        Args:
            params: mapping with ``torso_params``, ``quantiles_params`` and
                ``head_params`` entries.
            idx_head: index of the head to evaluate.
            state: input passed to the torso.
            key: PRNG key used to sample the quantile levels.
            n_quantiles: number of quantile levels to sample.

        Returns:
            The head output for each sampled quantile, one row per quantile.
        """
        # Torso and quantile embedding use slot 0 for head 0 and slot 1 for
        # every other head.
        shared_idx = jax.lax.cond(idx_head >= 1, lambda: 1, lambda: 0)

        def pick_shared(leaf):
            return leaf[shared_idx]

        def pick_head(leaf):
            return leaf[idx_head]

        torso_params = jax.tree_util.tree_map(pick_shared, params["torso_params"])
        state_feature = self.torso.apply(torso_params, state)

        embedding_params = jax.tree_util.tree_map(pick_shared, params["quantiles_params"])
        quantiles_feature, _ = self.quantile_embedding.apply(embedding_params, key, n_quantiles)

        # Multiply each quantile embedding row with the same state feature.
        feature = jax.vmap(jnp.multiply, in_axes=(0, None))(quantiles_feature, state_feature)

        head_params = jax.tree_util.tree_map(pick_head, params["head_params"])
        return self.head.apply(head_params, feature)
|
|
|
|
|
class AtariiIQN:
    """Greedy-action wrapper around one head of ``AtariSharediIQNNet``."""

    def __init__(self, n_actions: int, idx_head: int) -> None:
        """Create the shared network and remember which head to query.

        Args:
            n_actions: number of outputs of the network head.
            idx_head: index of the head used by ``best_action``.
        """
        self.network = AtariSharediIQNNet(n_actions)
        self.idx_head = idx_head
        # Number of quantile levels sampled when choosing an action.
        self.n_quantiles_policy = 32

    # `self` is a static argument, so each instance gets its own compilation
    # and `self.idx_head` / `self.n_quantiles_policy` are plain Python values
    # while tracing.
    @partial(jax.jit, static_argnames="self")
    def best_action(self, params: FrozenDict, state: jnp.ndarray, key: jnp.ndarray) -> jnp.int8:
        """Return the action with the largest mean quantile value for ``state``.

        Args:
            params: parameters forwarded to ``self.network.apply``.
            state: raw observation; it is passed through ``scale`` first.
            key: PRNG key used to sample the quantile levels. It is annotated
                as ``jnp.ndarray`` because the ``jax.random.PRNGKeyArray``
                alias does not exist in recent JAX releases, which would break
                import of this module.

        Returns:
            The argmax of the quantile-averaged network output, cast to
            ``int8``.
        """
        q_quantiles = self.network.apply(params, self.idx_head, scale(state), key, self.n_quantiles_policy)
        # Average over the sampled quantiles (axis 0) to get one value per action.
        q_values = jnp.mean(q_quantiles, axis=0)

        return jnp.argmax(q_values).astype(jnp.int8)
|
|