File size: 1,208 Bytes
8e63d2a
 
79943a9
 
 
 
 
 
 
 
 
 
8e63d2a
79943a9
 
 
 
 
8e63d2a
 
 
79943a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from typing import Sequence, Type

import numpy as np
import torch.nn as nn


def mlp(
    layer_sizes: Sequence[int],
    activation: Type[nn.Module],
    output_activation: Type[nn.Module] = nn.Identity,
    init_layers_orthogonal: bool = False,
    final_layer_gain: float = np.sqrt(2),
    hidden_layer_gain: float = np.sqrt(2),
) -> nn.Module:
    """Assemble a fully connected network as an ``nn.Sequential``.

    Consecutive entries of ``layer_sizes`` become the input/output widths of
    ``nn.Linear`` layers. Every linear layer except the last is followed by a
    fresh ``activation()`` instance; the last is followed by
    ``output_activation()``.

    Args:
        layer_sizes: Widths of the network, input first and output last.
            At least two entries are needed, since the final layer is built
            from ``layer_sizes[-2]`` and ``layer_sizes[-1]``.
        activation: Module class instantiated (with no arguments) after each
            hidden linear layer.
        output_activation: Module class instantiated (with no arguments)
            after the final linear layer.
        init_layers_orthogonal: Forwarded to ``layer_init`` for every linear
            layer.
        final_layer_gain: Passed to ``layer_init`` as ``std`` for the final
            linear layer.
        hidden_layer_gain: Passed to ``layer_init`` as ``std`` for every
            hidden linear layer.

    Returns:
        An ``nn.Sequential`` holding the layers in order.
    """

    def make_linear(fan_in: int, fan_out: int, gain: float) -> nn.Module:
        # Single place where a Linear layer is created and handed to layer_init.
        return layer_init(
            nn.Linear(fan_in, fan_out), init_layers_orthogonal, std=gain
        )

    modules = []
    hidden_count = len(layer_sizes) - 2
    for idx in range(hidden_count):
        modules += [
            make_linear(layer_sizes[idx], layer_sizes[idx + 1], hidden_layer_gain),
            activation(),
        ]

    # The output layer uses its own gain and its own activation.
    modules += [
        make_linear(layer_sizes[-2], layer_sizes[-1], final_layer_gain),
        output_activation(),
    ]
    return nn.Sequential(*modules)


def layer_init(
    layer: nn.Module, init_layers_orthogonal: bool, std: float = np.sqrt(2)
) -> nn.Module:
    """Optionally apply orthogonal initialization to ``layer`` in place.

    Args:
        layer: Module exposing a ``weight`` tensor and, optionally, a
            ``bias`` tensor (e.g. ``nn.Linear``).
        init_layers_orthogonal: When false, ``layer`` is returned untouched.
        std: Gain handed to ``nn.init.orthogonal_`` for the weight.

    Returns:
        The same ``layer`` object that was passed in.
    """
    if not init_layers_orthogonal:
        return layer
    nn.init.orthogonal_(layer.weight, gain=std)  # type: ignore
    # Layers built without a bias (e.g. nn.Linear(..., bias=False)) expose
    # ``bias`` as None; there is nothing to zero in that case.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        nn.init.constant_(bias, 0.0)
    return layer