|
import numpy as np |
|
import torch |
|
|
|
|
|
class BaseModule(torch.nn.Module):
    """Base ``torch.nn.Module`` adding parameter counting and input relocation helpers."""

    def __init__(self):
        super().__init__()

    @property
    def nparams(self):
        """
        Returns number of trainable parameters of the module.

        Only parameters with ``requires_grad=True`` are counted. Parameters shared
        between submodules are counted once, because ``Module.parameters()``
        de-duplicates them.

        Returns:
            int: total element count of all trainable parameters (0 if none).
        """
        # numel() reads the element count from tensor metadata. It needs no
        # device->host copy and no NumPy conversion, which fails for dtypes such
        # as bfloat16. It also yields an int, not 1.0, for 0-dim parameters.
        return sum(param.numel() for param in self.parameters() if param.requires_grad)

    def relocate_input(self, x: list):
        """
        Relocates provided tensors to the same device set for the module.

        The target device is taken from the module's first parameter or, if the
        module has no parameters, from its first buffer. Non-tensor entries are
        left untouched.

        Args:
            x: list whose ``torch.Tensor`` entries should live on the module's
                device. It is modified in place.

        Returns:
            list: the same list object ``x``. If the module has neither parameters
            nor buffers, there is no device to relocate to and ``x`` is returned
            unchanged.
        """
        # next() with a default avoids a bare StopIteration on parameter-free modules.
        reference = next(self.parameters(), None)
        if reference is None:
            reference = next(self.buffers(), None)
        if reference is None:
            return x
        device = reference.device

        for i, item in enumerate(x):
            if isinstance(item, torch.Tensor) and item.device != device:
                x[i] = item.to(device)
        return x
|
|