cam-ml-yog-v0 / models.py
tztsai
rename Fortran NN file
d09d703
raw
history blame contribute delete
No virus
6.01 kB
"""Neural network architectures."""
from typing import Optional
import netCDF4 as nc # type: ignore
import torch
from torch import nn, Tensor
class ANN(nn.Sequential):
    """Model used in the paper.

    Paper: https://doi.org/10.1029/2020GL091363

    A stack of ``n_layers`` ``nn.Linear`` layers; every layer except the last
    is followed by ``nn.ReLU`` and ``nn.Dropout``. When normalization
    statistics are supplied, inputs are standardized before the first layer
    and outputs are de-standardized after the last one (see ``forward``).

    Parameters
    ----------
    n_in : int
        Number of input features.
    n_out : int
        Number of output features.
    n_layers : int
        Number of linear layers (hidden layers plus the output layer).
    neurons : int
        The number of neurons in the hidden layers.
    dropout : float
        The dropout probability to apply in the hidden layers.
    device : str
        The device to put the model on.
    features_mean : array-like or Tensor, optional
        The mean of the input features. Must be given together with
        ``features_std``.
    features_std : array-like or Tensor, optional
        The standard deviation of the input features.
    outputs_mean : array-like or Tensor, optional
        The mean of the output features: one value per output feature, or one
        value per output group when ``output_groups`` is given. Must be given
        together with ``outputs_std``.
    outputs_std : array-like or Tensor, optional
        The standard deviation of the output features, laid out like
        ``outputs_mean``.
    output_groups : list of int, optional
        The number of output features in each group of the output. When
        given, each per-group mean/std value is repeated that many times.

    Notes
    -----
    If you are doing inference, always remember to put the model in eval mode,
    by using ``model.eval()``, so the dropout layers are turned off.
    """

    def __init__(  # pylint: disable=too-many-arguments,too-many-locals
        self,
        n_in: int = 61,
        n_out: int = 148,
        n_layers: int = 5,
        neurons: int = 128,
        dropout: float = 0.0,
        device: str = "cpu",
        features_mean: Optional[Tensor] = None,
        features_std: Optional[Tensor] = None,
        outputs_mean: Optional[Tensor] = None,
        outputs_std: Optional[Tensor] = None,
        output_groups: Optional[list] = None,
    ):
        """Initialize the ANN model."""
        dims = [n_in] + [neurons] * (n_layers - 1) + [n_out]
        layers = []
        for i in range(n_layers):
            layers.append(nn.Linear(dims[i], dims[i + 1]))
            if i < n_layers - 1:
                # Only hidden layers get an activation and dropout; the output
                # layer stays linear.
                layers.append(nn.ReLU())  # type: ignore
                layers.append(nn.Dropout(dropout))  # type: ignore
        super().__init__(*layers)

        fmean = fstd = omean = ostd = None
        if features_mean is not None:
            assert features_std is not None, "features_std is required with features_mean"
            assert len(features_mean) == len(features_std), (
                "features_mean and features_std must have the same length"
            )
            fmean = self._to_tensor(features_mean)
            fstd = self._to_tensor(features_std)
        if outputs_mean is not None:
            assert outputs_std is not None, "outputs_std is required with outputs_mean"
            assert len(outputs_mean) == len(outputs_std), (
                "outputs_mean and outputs_std must have the same length"
            )
            omean = self._to_tensor(outputs_mean)
            ostd = self._to_tensor(outputs_std)
            if output_groups is not None:
                assert len(output_groups) == len(outputs_mean), (
                    "output_groups must have one entry per outputs_mean value"
                )
                # Expand one statistic per group into one per output feature.
                repeats = torch.as_tensor(output_groups)
                omean = torch.repeat_interleave(omean, repeats)
                ostd = torch.repeat_interleave(ostd, repeats)

        # Buffers (not parameters): they are saved in the state dict and follow
        # .to(device), but are not trained. They may be None.
        self.register_buffer("features_mean", fmean)
        self.register_buffer("features_std", fstd)
        self.register_buffer("outputs_mean", omean)
        self.register_buffer("outputs_std", ostd)
        self.to(torch.device(device))

    @staticmethod
    def _to_tensor(values) -> Tensor:
        """Copy ``values`` (list, array or Tensor) into a new, detached tensor.

        ``torch.tensor`` warns when handed a Tensor; ``as_tensor`` accepts all
        three input kinds, and ``clone`` keeps the buffer from aliasing the
        caller's memory.
        """
        return torch.as_tensor(values).detach().clone()

    def _device(self) -> torch.device:
        """Return the device of the first parameter (CPU if there are none)."""
        first = next(self.parameters(), None)
        return torch.device("cpu") if first is None else first.device

    def forward(self, input: Tensor):  # pylint: disable=redefined-builtin
        """Pass the input through the model.

        Override the forward method of nn.Sequential to add normalization
        to the input and denormalization to the output.

        Parameters
        ----------
        input : Tensor
            A mini-batch of inputs.

        Returns
        -------
        Tensor
            The model output.
        """
        if self.features_mean is not None:
            input = (input - self.features_mean) / self.features_std
        # pass the input through the layers using nn.Sequential.forward
        output = super().forward(input)
        if self.outputs_mean is not None:
            output = output * self.outputs_std + self.outputs_mean
        return output

    def load(self, path: str) -> "ANN":
        """Load the model from a checkpoint.

        Parameters
        ----------
        path : str
            The path to the checkpoint.

        Returns
        -------
        ANN
            ``self``, with parameters and buffers taken from the checkpoint.
        """
        # Map every stored tensor onto this model's device, so a checkpoint
        # written on a GPU can be restored on a CPU-only machine and restored
        # buffers end up next to the parameters.
        state = torch.load(path, map_location=self._device())
        for key in ["features_mean", "features_std", "outputs_mean", "outputs_std"]:
            # Buffers registered as None are left out of state_dict(), so a
            # strict load_state_dict would reject these keys; install them
            # directly first.
            if key in state and getattr(self, key) is None:
                setattr(self, key, state[key])
        self.load_state_dict(state)
        return self

    def save(self, path: str):
        """Save the model to a checkpoint.

        Parameters
        ----------
        path : str
            The path to save the checkpoint to.
        """
        torch.save(self.state_dict(), path)
def load_from_netcdf_params(
    nc_file: str,
    dtype: str = "float32",
    output_groups: Optional[list] = None,
) -> ANN:
    """Load the model with weights and biases from the netcdf file.

    The file must contain the normalization statistics ``fscale_mean``,
    ``fscale_stnd``, ``oscale_mean`` and ``oscale_stnd``, plus one weight/bias
    pair per linear layer, named ``w1``/``b1``, ``w2``/``b2``, ...

    Parameters
    ----------
    nc_file : str
        The netcdf file containing the parameters.
    dtype : str
        The data type to cast the parameters to.
    output_groups : list of int, optional
        Number of output features in each output group. The per-group output
        scaling stored in the file is repeated accordingly. Defaults to
        ``[30, 29, 29, 30, 30]`` (148 outputs, matching ``ANN``'s default
        ``n_out``).

    Returns
    -------
    ANN
        A model with the default ``ANN`` architecture holding the file's
        parameters.

    Raises
    ------
    ValueError
        If a weight or bias in the file does not have the shape of the
        corresponding linear layer.
    """
    if output_groups is None:
        output_groups = [30, 29, 29, 30, 30]
    # The context manager closes the file even if reading or validation fails.
    with nc.Dataset(nc_file) as data_set:  # pylint: disable=no-member

        def read(name: str):
            """Read a whole variable from the file, cast to ``dtype``."""
            return data_set[name][:].astype(dtype)

        model = ANN(
            features_mean=read("fscale_mean"),
            features_std=read("fscale_stnd"),
            outputs_mean=read("oscale_mean"),
            outputs_std=read("oscale_stnd"),
            output_groups=output_groups,
        )
        linears = [m for m in model.modules() if isinstance(m, nn.Linear)]
        for i, layer in enumerate(linears, start=1):
            weight = torch.tensor(read(f"w{i}"))
            bias = torch.tensor(read(f"b{i}"))
            # Assigning .data would silently accept any shape and only fail
            # later inside forward(); check up front instead.
            if weight.shape != layer.weight.shape or bias.shape != layer.bias.shape:
                raise ValueError(
                    f"layer {i}: file has w{i} {tuple(weight.shape)} / "
                    f"b{i} {tuple(bias.shape)}, model expects "
                    f"{tuple(layer.weight.shape)} / {tuple(layer.bias.shape)}"
                )
            # Replace .data (rather than copy_) so the parameters take the
            # requested dtype, matching the buffers cast the same way.
            layer.weight.data = weight
            layer.bias.data = bias
    return model
if __name__ == "__main__":
    # Script entry point: convert the netCDF parameter file into a PyTorch
    # state-dict checkpoint next to it.
    converted = load_from_netcdf_params("NN_weights_YOG_convection.nc")
    converted.save("nn_state.pt")
    print("Model saved to nn_state.pt")