|
|
|
|
|
from prettytable import PrettyTable |
|
|
|
def count_parameters(model):
    """Print a per-parameter table and return the number of trainable parameters.

    Only parameters whose ``requires_grad`` flag is true are counted. Frozen
    parameters are skipped entirely: they appear neither in the printed table
    nor in the returned total.

    Side effects:
        Prints a two-column table ("Modules", "Parameters") followed by a
        "Total Trainable Params: <n>" line to stdout.

    Args:
        model: An object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs, where each parameter provides
            ``requires_grad`` and ``numel()``. Presumably a
            ``torch.nn.Module`` -- confirm against callers.

    Returns:
        The summed element count (``numel()``) of all trainable parameters.
    """
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0

    for name, parameter in model.named_parameters():
        # Exclude frozen parameters so both the table and the total
        # reflect trainable parameters only.
        if not parameter.requires_grad:
            continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params += params

    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
|
|
|
|
|
def batchify(lst, n: int) -> list:
    """Split a sequence into consecutive chunks of at most ``n`` items.

    Works on any sliceable sequence (list, tuple, str, ...); each chunk has
    the same type as ``lst`` slices. Every chunk has exactly ``n`` items
    except possibly the last, which holds the remainder. An empty input
    yields an empty list.

    Args:
        lst: The sequence to split.
        n: Chunk size; must be a positive integer.

    Returns:
        A list of slices of ``lst``, in original order.

    Raises:
        ValueError: If ``n`` is less than 1. Without this check a negative
            ``n`` would silently return ``[]`` and drop all data.

    >>> batchify([1, 2, 3, 4, 5], 2)
    [[1, 2], [3, 4], [5]]
    """
    if n < 1:
        raise ValueError(f"chunk size n must be a positive integer, got {n!r}")
    return [lst[start:start + n] for start in range(0, len(lst), n)]
|
|
|
|