cam-ml-yog-v0 / test_python_net.py
tztsai's picture
Upload models
ec86bf7 verified
raw
history blame
No virus
1.89 kB
"""A smoke test for the ANN model.

This test checks that the model can be loaded from a weights file in both PyTorch (.pt)
format and NetCDF format, and that both loaded models produce the expected output when
given an input of all ones.
This ensures that the Python model is equivalent to the Fortran NN model.
"""
import os
from pathlib import Path
import torch
import numpy as np
from models import ANN, load_from_netcdf_params
# Work from the directory that contains this script: every file name used
# below (reference output and both weight files) is given as a relative path.
script_dir = Path(__file__).parent
os.chdir(script_dir)

# nn_ones.txt is the output of the Fortran NN model given an input of all ones.
reference_output = np.loadtxt("nn_ones.txt")
expected = reference_output.astype(np.float32)

# Model 1: loaded from the PyTorch weights.
pytorch_weights_file = "nn_state.pt"
model1 = ANN().load(pytorch_weights_file)

# Model 2: loaded from the NetCDF weights of the pretrained Fortran NN model.
# The file was created at https://github.com/yaniyuval/Neural_nework_parameterization/blob/f81f5f695297888f0bd1e0e61524590b4566bf03/NN_training/src/ml_train_nn.py#L417 # pylint: disable=line-too-long
# (its naming scheme encodes information about the training setup, see e.g., https://github.com/yaniyuval/Neural_nework_parameterization/blob/f81f5f695297888f0bd1e0e61524590b4566bf03/NN_training/src/ml_train_nn.py#L263-L265) # pylint: disable=line-too-long
# This neural net can be found at https://github.com/yaniyuval/Neural_nework_parameterization/tree/f81f5f695297888f0bd1e0e61524590b4566bf03/NNs # pylint: disable=line-too-long
netcdf_weights_file = (
    "qobsTTFFFFFTF30FFTFTF30TTFTFTFFF80FFTFTTF2699FFFF_X01_no_qp_no_adv_"
    "surf_F_Tin_qin_disteq_O_Trad_rest_Tadv_qadv_qout_qsed_RESCALED_7epochs"
    "_no_drop_REAL_NN_layers5in61out148_BN_F_te70.nc"
)
model2 = load_from_netcdf_params(netcdf_weights_file)

# Feed the same all-ones input vector through both models, in order, and
# convert each result to a NumPy array for comparison.
x = torch.ones(61)
actual1, actual2 = (
    model.forward(x).detach().numpy() for model in (model1, model2)
)

# The two loading paths must give exactly the same output ...
assert np.all(actual1 == actual2)
# ... and that output must match the Fortran reference within tolerance.
# Values of atol and rtol are chosen to be the lowest that still pass the test.
assert np.allclose(expected, actual1, rtol=2e-6, atol=3e-8)

print("Smoke tests passed")