|
def count_parameters(model, verbose=True):
    """Count the parameters of a PyTorch model.

    Walks ``model.parameters()`` once, adding each parameter's element
    count (``numel()``) to a running total. Parameters whose
    ``requires_grad`` flag is set are also added to a separate
    "trainable" total.

    Reference:
        https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/7

    Example:
        from utils.utils import count_parameters
        count_parameters(model)

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        verbose: When true, print both totals (with thousands separators).

    Returns:
        Tuple ``(n_all, n_trainable)`` of element counts.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        # Only parameters that will receive gradients count as trainable.
        if param.requires_grad:
            trainable += size

    if verbose:
        print("Parameter Count: all {:,d}; trainable {:,d}".format(total, trainable))

    return total, trainable
|
|
|
|