# Atari_i-QN / networks.py
# Author: TheoVincent — commit 4cb4fc3 ("minimal example - code"), 5.52 kB.
from flax.core import FrozenDict
import flax.linen as nn
import jax
import jax.numpy as jnp
from functools import partial
# --- Base functions ---
def scale(state: jnp.ndarray) -> jnp.ndarray:
    """Normalize raw pixel values by dividing by 255, so [0, 255] maps to [0, 1]."""
    max_pixel_value = 255.0
    return state / max_pixel_value
class Torso(nn.Module):
    """Convolutional feature extractor used by both the i-DQN and i-IQN networks.

    Three conv layers, each followed by a ReLU:
    32 filters 8x8 stride 4, 64 filters 4x4 stride 2, 64 filters 3x3 stride 1.
    The final feature map is flattened into a single vector.

    Attributes:
        initialization_type: ``"dqn"`` or ``"iqn"``. Selects the kernel
            initializer shared by all three conv layers.
    """

    initialization_type: str

    @nn.compact
    def __call__(self, state):
        """Return the flattened convolutional features of ``state``.

        Raises:
            ValueError: if ``initialization_type`` is neither ``"dqn"`` nor ``"iqn"``.
        """
        if self.initialization_type == "dqn":
            initializer = nn.initializers.variance_scaling(scale=1.0, mode="fan_avg", distribution="truncated_normal")
        elif self.initialization_type == "iqn":
            initializer = nn.initializers.variance_scaling(
                scale=1.0 / jnp.sqrt(3.0), mode="fan_in", distribution="uniform"
            )
        else:
            # Without this branch an unknown value surfaced as an UnboundLocalError
            # on `initializer` below, hiding the real cause.
            raise ValueError(
                f"Unknown initialization_type {self.initialization_type!r}; expected 'dqn' or 'iqn'."
            )
        x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4), kernel_init=initializer)(state)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2), kernel_init=initializer)(x)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1), kernel_init=initializer)(x)
        x = nn.relu(x)
        # flatten() collapses every axis, including any leading batch axis, into
        # one vector. This presumably expects a single unbatched state — TODO confirm.
        return x.flatten()
class Head(nn.Module):
    """Two-layer MLP head: Dense(512) -> ReLU -> Dense(n_actions).

    Dense layers act on the last axis. A feature vector therefore yields one
    value per action, and an input with leading axes (for example
    ``(n_quantiles, n_features)``) yields ``(..., n_actions)``.

    Attributes:
        n_actions: number of output units, one per action.
        initialization_type: ``"dqn"`` or ``"iqn"``. Selects the kernel
            initializer shared by both Dense layers.
    """

    n_actions: int
    initialization_type: str

    @nn.compact
    def __call__(self, x):
        """Map features ``x`` to per-action outputs.

        Raises:
            ValueError: if ``initialization_type`` is neither ``"dqn"`` nor ``"iqn"``.
        """
        if self.initialization_type == "dqn":
            initializer = nn.initializers.variance_scaling(scale=1.0, mode="fan_avg", distribution="truncated_normal")
        elif self.initialization_type == "iqn":
            initializer = nn.initializers.variance_scaling(
                scale=1.0 / jnp.sqrt(3.0), mode="fan_in", distribution="uniform"
            )
        else:
            # Without this branch an unknown value surfaced as an UnboundLocalError
            # on `initializer` below, hiding the real cause.
            raise ValueError(
                f"Unknown initialization_type {self.initialization_type!r}; expected 'dqn' or 'iqn'."
            )
        x = nn.Dense(features=512, kernel_init=initializer)(x)
        x = nn.relu(x)
        return nn.Dense(features=self.n_actions, kernel_init=initializer)(x)
class QuantileEmbedding(nn.Module):
    """Sample quantile fractions and embed them with a cosine basis.

    Each sampled fraction ``tau`` is expanded to ``cos(pi * i * tau)`` for
    ``i = 1 .. quantile_embedding_dim``. The expansion is projected by a Dense
    layer to ``n_features`` units and passed through a ReLU.

    Attributes:
        n_features: width of the projected embedding. It is multiplied
            elementwise with the torso features downstream, so it presumably
            must match the torso's flattened output size — TODO confirm for
            non-default input resolutions.
        quantile_embedding_dim: number of cosine basis functions per fraction.
    """

    n_features: int = 7744
    quantile_embedding_dim: int = 64

    @nn.compact
    def __call__(self, key, n_quantiles):
        """Return ``(embedding, taus)``.

        ``embedding`` has shape ``(n_quantiles, n_features)``. ``taus`` has
        shape ``(n_quantiles,)`` and holds fractions drawn uniformly with ``key``.
        """
        taus = jax.random.uniform(key, shape=(n_quantiles, 1))
        frequencies = jnp.arange(1, self.quantile_embedding_dim + 1)[jnp.newaxis, :]
        # (n_quantiles, 1) @ (1, quantile_embedding_dim) -> one cosine feature per frequency.
        cosine_basis = jnp.cos((jnp.pi * taus) @ frequencies)
        projection = nn.Dense(
            features=self.n_features,
            kernel_init=nn.initializers.variance_scaling(
                scale=1.0 / jnp.sqrt(3.0), mode="fan_in", distribution="uniform"
            ),
        )
        embedding = nn.relu(projection(cosine_basis))
        return embedding, taus[:, 0]
# --- i-DQN networks ---
class AtariSharediDQNNet:
    """i-DQN network: several Q-heads with partially shared torso parameters.

    ``params`` maps ``"torso_params_<t>"`` and ``"head_params_<k>"`` to
    parameter trees. Head ``k`` is evaluated with torso ``min(k, 1)``, so
    head 0 uses ``"torso_params_0"`` and every later head uses ``"torso_params_1"``.
    """

    def __init__(self, n_actions: int) -> None:
        self.n_actions = n_actions
        # Number of heads (hard-coded for this architecture).
        self.n_heads = 5
        # Module definitions only; their parameters are supplied to `apply`.
        self.torso = Torso("dqn")
        self.head = Head(self.n_actions, "dqn")

    def apply(self, params: FrozenDict, idx_head: int, state: jnp.ndarray) -> jnp.ndarray:
        """Evaluate head ``idx_head`` on ``state`` and return its Q-values.

        ``state`` is passed to the torso unchanged; no pixel scaling happens here.
        """
        torso_id = min(idx_head, 1)
        features = self.torso.apply(params[f"torso_params_{torso_id}"], state)
        head_params = params[f"head_params_{idx_head}"]
        return self.head.apply(head_params, features)
class AtariiDQN:
    """Greedy policy that acts with a single head of an i-DQN network."""

    def __init__(self, n_actions: int, idx_head: int) -> None:
        self.network = AtariSharediDQNNet(n_actions)
        # Index of the head whose Q-values drive action selection.
        self.idx_head = idx_head

    # `self` is a static jit argument, so instances must be hashable; this class
    # defines no __hash__/__eq__, so the default identity-based hash is used.
    @partial(jax.jit, static_argnames="self")
    def best_action(self, params: FrozenDict, state: jnp.ndarray, key: jnp.ndarray) -> jnp.int8:
        """Return the argmax action of head ``idx_head`` for a raw ``state``.

        Args:
            params: network parameters (see ``AtariSharediDQNNet.apply``).
            state: raw observation; divided by 255 via ``scale`` before the network.
            key: PRNG key. Unused here; kept so the signature matches
                ``AtariiIQN.best_action``.

        Returns:
            The greedy action index as int8 (assumes ``n_actions <= 127``).
        """
        # `key` is annotated as jnp.ndarray: `jax.random.PRNGKeyArray` was removed
        # from JAX, and annotations here are evaluated at definition time.
        q_values = self.network.apply(params, self.idx_head, scale(state))
        return jnp.argmax(q_values).astype(jnp.int8)
# --- i-IQN networks ---
class AtariSharediIQNNet:
    """i-IQN network: quantile heads on top of partially shared torsos.

    ``params["torso_params"]``, ``params["quantiles_params"]`` and
    ``params["head_params"]`` are parameter trees whose leaves are stacked
    along a leading axis. Every leaf is indexed with ``param[i]``.

    For the torso and quantile embedding, index 0 serves head 0 and index 1
    serves all later heads. For the heads, index ``idx_head`` is used directly.
    The stacked sizes (presumably 2 and ``n_heads``) are not checked — TODO confirm.
    """

    def __init__(self, n_actions: int) -> None:
        self.n_heads = 4
        self.n_actions = n_actions
        self.torso = Torso("iqn")
        self.quantile_embedding = QuantileEmbedding()
        self.head = Head(self.n_actions, "iqn")

    def apply(
        self, params: FrozenDict, idx_head: int, state: jnp.ndarray, key: jnp.ndarray, n_quantiles: int
    ) -> jnp.ndarray:
        """Return the quantile values of head ``idx_head``, shape ``(n_quantiles, n_actions)``.

        Args:
            params: stacked parameter trees (see class docstring).
            idx_head: head index. May be a Python int or a traced integer.
            state: observation passed to the torso unchanged (no pixel scaling here).
            key: PRNG key used to sample the quantile fractions.
            n_quantiles: number of quantile fractions to sample.
        """
        # Shared-parameter slot: 0 for head 0, 1 for every later head. Computed once
        # with a select (not Python `min`) so it stays valid when idx_head is traced.
        idx_shared = jnp.where(idx_head >= 1, 1, 0)

        def select(stacked_params, idx):
            # Take slice `idx` of every leaf along its leading (stacking) axis.
            return jax.tree_util.tree_map(lambda param: param[idx], stacked_params)

        # output (n_features,)
        state_feature = self.torso.apply(select(params["torso_params"], idx_shared), state)
        # output (n_quantiles, n_features)
        quantiles_feature, _ = self.quantile_embedding.apply(
            select(params["quantiles_params"], idx_shared), key, n_quantiles
        )
        # Broadcasting multiplies each quantile row by the state features;
        # this is identical to the former per-row vmap. Output (n_quantiles, n_features).
        feature = quantiles_feature * state_feature
        # output (n_quantiles, n_actions)
        return self.head.apply(select(params["head_params"], idx_head), feature)
class AtariiIQN:
    """Greedy policy that acts with a single head of an i-IQN network."""

    def __init__(self, n_actions: int, idx_head: int, n_quantiles_policy: int = 32) -> None:
        self.network = AtariSharediIQNNet(n_actions)
        # Index of the head whose quantile values drive action selection.
        self.idx_head = idx_head
        # Number of quantile fractions sampled per action-selection call.
        # Defaults to the previously hard-coded value of 32.
        self.n_quantiles_policy = n_quantiles_policy

    # `self` is a static jit argument, so instances must be hashable; this class
    # defines no __hash__/__eq__, so the default identity-based hash is used.
    @partial(jax.jit, static_argnames="self")
    def best_action(self, params: FrozenDict, state: jnp.ndarray, key: jnp.ndarray) -> jnp.int8:
        """Return the action with the highest mean sampled quantile value.

        Args:
            params: stacked network parameters (see ``AtariSharediIQNNet``).
            state: raw observation; divided by 255 via ``scale`` before the network.
            key: PRNG key used to sample the quantile fractions.

        Returns:
            The greedy action index as int8 (assumes ``n_actions <= 127``).
        """
        # `key` is annotated as jnp.ndarray: `jax.random.PRNGKeyArray` was removed
        # from JAX, and annotations here are evaluated at definition time.
        # output (n_quantiles, n_actions)
        q_quantiles = self.network.apply(params, self.idx_head, scale(state), key, self.n_quantiles_policy)
        # Average over the sampled fractions to estimate one value per action.
        q_values = jnp.mean(q_quantiles, axis=0)
        return jnp.argmax(q_values).astype(jnp.int8)