File size: 6,624 Bytes
5b9b09f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import ABC, abstractmethod
from gym.spaces import Box, Discrete
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
from typing import Dict, Optional, Type
from shared.module.module import layer_init
class CnnFeatureExtractor(nn.Module, ABC):
    """Abstract base class for the CNN feature extractors in this module.

    ``__init__`` is declared abstract, so this class cannot be instantiated
    directly; concrete subclasses override the constructor and chain up to
    this one.
    """

    @abstractmethod
    def __init__(
        self,
        in_channels: int,
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: Optional[bool] = None,
    ) -> None:
        """Initialise the underlying ``nn.Module``.

        Args:
            in_channels: Number of channels of the input.
            activation: Activation module class.
            init_layers_orthogonal: Layer-initialisation flag; ``None`` lets
                a subclass choose its own default.

        None of the arguments are used here; they only define the
        constructor signature that subclasses share.
        """
        super().__init__()
class NatureCnn(CnnFeatureExtractor):
    """
    CNN from DQN Nature paper: Mnih, Volodymyr, et al.
    "Human-level control through deep reinforcement learning."
    Nature 518.7540 (2015): 529-533.

    Three convolutions, each followed by ``activation``, with the result
    flattened to one feature vector per sample.
    """

    def __init__(
        self,
        in_channels: int,
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: Optional[bool] = None,
    ) -> None:
        """Build the convolutional stack.

        Args:
            in_channels: Number of channels of the input image.
            activation: Activation module class instantiated after each conv.
            init_layers_orthogonal: Flag forwarded to ``layer_init`` for every
                convolution; ``None`` is treated as ``True``.
        """
        orthogonal = True if init_layers_orthogonal is None else init_layers_orthogonal
        super().__init__(in_channels, activation, orthogonal)

        # One row per convolution: (out_channels, kernel_size, stride).
        conv_specs = ((32, 8, 4), (64, 4, 2), (64, 3, 1))

        layers = []
        prev_channels = in_channels
        for n_filters, kernel, step in conv_specs:
            conv = nn.Conv2d(prev_channels, n_filters, kernel_size=kernel, stride=step)
            layers.append(layer_init(conv, orthogonal))
            layers.append(activation())
            prev_channels = n_filters
        layers.append(nn.Flatten())

        self.cnn = nn.Sequential(*layers)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Run ``obs`` through the conv stack and return the flattened output."""
        return self.cnn(obs)
class ResidualBlock(nn.Module):
    """Two (activation, 3x3 convolution) stages wrapped by a skip connection.

    The convolutions keep the channel count and use ``padding=1``, so the
    residual branch output has the same shape as its input and can be added
    to it.
    """

    def __init__(
        self,
        channels: int,
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: bool = False,
    ) -> None:
        """Build the residual branch.

        Args:
            channels: Number of input and output channels of both convs.
            activation: Activation module class placed before each conv.
            init_layers_orthogonal: Flag forwarded to ``layer_init``.
        """
        super().__init__()
        stages = []
        for _ in range(2):
            stages.append(activation())
            conv = nn.Conv2d(channels, channels, 3, padding=1)
            stages.append(layer_init(conv, init_layers_orthogonal))
        self.residual = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the output of the residual branch applied to ``x``."""
        branch_out = self.residual(x)
        return x + branch_out
class ConvSequence(nn.Module):
    """A 3x3 convolution, a stride-2 max-pool and two residual blocks in series."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: bool = False,
    ) -> None:
        """Assemble the sequence.

        Args:
            in_channels: Channels of the incoming feature map.
            out_channels: Channels produced by the first convolution and kept
                by the residual blocks.
            activation: Activation module class handed to the residual blocks.
            init_layers_orthogonal: Flag forwarded to ``layer_init`` and to
                the residual blocks.
        """
        super().__init__()
        entry_conv = layer_init(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            init_layers_orthogonal,
        )
        downsample = nn.MaxPool2d(3, stride=2, padding=1)
        first_block = ResidualBlock(out_channels, activation, init_layers_orthogonal)
        second_block = ResidualBlock(out_channels, activation, init_layers_orthogonal)
        self.seq = nn.Sequential(entry_conv, downsample, first_block, second_block)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv, pooling and both residual blocks to ``x``."""
        return self.seq(x)
class ImpalaCnn(CnnFeatureExtractor):
    """
    IMPALA-style CNN architecture

    Three ``ConvSequence`` stages with 16, 32 and 32 output channels,
    followed by an activation and a flatten.
    """

    def __init__(
        self,
        in_channels: int,
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: Optional[bool] = None,
    ) -> None:
        """Build the stack of conv sequences.

        Args:
            in_channels: Number of channels of the input image.
            activation: Activation module class used inside every stage and
                once more after the last stage.
            init_layers_orthogonal: Flag forwarded to every ``ConvSequence``;
                ``None`` is treated as ``False``.
        """
        orthogonal = False if init_layers_orthogonal is None else init_layers_orthogonal
        super().__init__(in_channels, activation, orthogonal)

        # Consecutive pairs give each stage's (input, output) channel counts.
        widths = [in_channels, 16, 32, 32]
        stages = [
            ConvSequence(n_in, n_out, activation, orthogonal)
            for n_in, n_out in zip(widths, widths[1:])
        ]
        self.seq = nn.Sequential(*stages, activation(), nn.Flatten())

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Run ``obs`` through all stages and return the flattened output."""
        return self.seq(obs)
# Registry mapping a ``cnn_style`` name to the extractor class implementing it.
CNN_EXTRACTORS_BY_STYLE: Dict[str, Type[CnnFeatureExtractor]] = {
    "nature": NatureCnn,  # DQN Nature-paper network
    "impala": ImpalaCnn,  # IMPALA-style residual network
}
class FeatureExtractor(nn.Module):
    """Map raw observations to flat, batched feature tensors.

    The constructor picks a preprocessing function and a feature-extractor
    module from the observation space:

    * 3-D ``Box``: a CNN chosen by ``cnn_style`` followed by a linear layer
      and an activation; ``out_dim`` is ``cnn_feature_dim``.
    * 1-D ``Box``: values are cast to float and flattened; ``out_dim`` is the
      flattened observation size.
    * ``Discrete``: observations are one-hot encoded; ``out_dim`` is
      ``obs_space.n``.

    In every branch an unbatched observation gets a leading batch dimension
    of size 1 before it reaches ``feature_extractor``.
    """

    def __init__(
        self,
        obs_space: gym.Space,
        activation: Type[nn.Module],
        init_layers_orthogonal: bool = False,
        cnn_feature_dim: int = 512,
        cnn_style: str = "nature",
        cnn_layers_init_orthogonal: Optional[bool] = None,
    ) -> None:
        """Build the preprocessing function and extractor for ``obs_space``.

        Args:
            obs_space: Observation space; ``Box`` with 1 or 3 dimensions, or
                ``Discrete``.
            activation: Activation module class, used by the CNN and after
                the linear layer of the image branch.
            init_layers_orthogonal: Flag forwarded to ``layer_init`` for the
                linear layer of the image branch.
            cnn_feature_dim: Output size of the image branch.
            cnn_style: Key into ``CNN_EXTRACTORS_BY_STYLE``.
            cnn_layers_init_orthogonal: Flag forwarded to the CNN
                constructor; ``None`` lets the CNN choose its default.

        Raises:
            ValueError: If ``obs_space`` is a ``Box`` whose shape is neither
                1-D nor 3-D.
            NotImplementedError: If ``obs_space`` is neither ``Box`` nor
                ``Discrete``.
            KeyError: If ``cnn_style`` is not a registered style.
        """
        super().__init__()
        if isinstance(obs_space, Box):
            # Conv2D: (channels, height, width)
            if len(obs_space.shape) == 3:
                cnn = CNN_EXTRACTORS_BY_STYLE[cnn_style](
                    obs_space.shape[0],
                    activation,
                    init_layers_orthogonal=cnn_layers_init_orthogonal,
                )

                def preprocess(obs: torch.Tensor) -> torch.Tensor:
                    # Add a batch dimension to a single (C, H, W) image.
                    if len(obs.shape) == 3:
                        obs = obs.unsqueeze(0)
                    # NOTE(review): assumes pixel values in [0, 255] — confirm.
                    return obs.float() / 255.0

                # Push one sampled observation through the CNN to discover
                # the flattened feature size needed by the linear layer.
                with torch.no_grad():
                    cnn_out = cnn(preprocess(torch.as_tensor(obs_space.sample())))
                self.preprocess = preprocess
                self.feature_extractor = nn.Sequential(
                    cnn,
                    layer_init(
                        nn.Linear(cnn_out.shape[1], cnn_feature_dim),
                        init_layers_orthogonal,
                    ),
                    activation(),
                )
                self.out_dim = cnn_feature_dim
            elif len(obs_space.shape) == 1:

                def preprocess(obs: torch.Tensor) -> torch.Tensor:
                    # Add a batch dimension to a single 1-D observation.
                    if len(obs.shape) == 1:
                        obs = obs.unsqueeze(0)
                    return obs.float()

                self.preprocess = preprocess
                self.feature_extractor = nn.Flatten()
                self.out_dim = get_flattened_obs_dim(obs_space)
            else:
                raise ValueError(f"Unsupported observation space: {obs_space}")
        elif isinstance(obs_space, Discrete):
            num_classes = int(obs_space.n)

            def preprocess(obs: torch.Tensor) -> torch.Tensor:
                # A scalar observation would one-hot to shape (n,), which
                # nn.Flatten (start_dim=1) rejects; add the batch dimension
                # first, as the Box branches do.
                if obs.dim() == 0:
                    obs = obs.unsqueeze(0)
                # F.one_hot only accepts int64 index tensors.
                return F.one_hot(obs.long(), num_classes).float()

            self.preprocess = preprocess
            self.feature_extractor = nn.Flatten()
            self.out_dim = obs_space.n
        else:
            raise NotImplementedError

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Preprocess ``obs`` and return the extracted features."""
        if self.preprocess is not None:
            obs = self.preprocess(obs)
        return self.feature_extractor(obs)
|