File size: 695 Bytes
251016c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
# Copyright (c) 2024, SliceX AI, Inc. All Rights Reserved.
from prettytable import PrettyTable
def count_parameters(model):
    """Print a per-parameter table of trainable parameter counts and return the total.

    Only parameters whose ``requires_grad`` flag is set are listed and summed;
    frozen parameters are skipped entirely.

    Args:
        model: An object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs, where each parameter provides
            ``requires_grad`` and ``numel()`` (presumably a
            ``torch.nn.Module`` -- confirm against callers).

    Returns:
        The sum of ``numel()`` over all trainable parameters (0 if none).
    """
    summary = PrettyTable(["Modules", "Parameters"])

    # Collect (name, element count) for trainable parameters only.
    trainable = [
        (param_name, tensor.numel())
        for param_name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]

    for param_name, count in trainable:
        summary.add_row([param_name, count])

    grand_total = sum(count for _, count in trainable)

    print(summary)
    print(f"Total Trainable Params: {grand_total}")
    return grand_total
def batchify(lst, n):
    """Split a sliceable sequence into consecutive chunks of length ``n``.

    The final chunk holds the remainder and may be shorter than ``n``. Each
    chunk is a slice of ``lst``, so it has the same type as ``lst`` (a list
    yields lists, a string yields strings, a tuple yields tuples).

    Args:
        lst: The sequence to split; must support ``len()`` and slicing.
        n: The chunk size; must be a positive integer.

    Returns:
        A list of chunks, in order. An empty ``lst`` yields ``[]``.

    Raises:
        ValueError: If ``n`` is less than 1.
    """
    # Without this guard, n == 0 crashes inside range() with an opaque
    # message, and a negative n silently returns [], discarding all data.
    if n < 1:
        raise ValueError(f"batch size must be a positive integer, got {n!r}")
    return [lst[start:start + n] for start in range(0, len(lst), n)]
|