|
def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of the given networks.

    Args:
        nets (nn.Module | list[nn.Module] | tuple[nn.Module]): A single
            network, or a list/tuple of networks. ``None`` entries are
            skipped.
        requires_grad (bool): Whether the networks require gradients or not.
    """
    # Anything that is not a list/tuple (e.g. a single module) is wrapped so
    # the loop below can treat every input uniformly. Tuples are accepted
    # as-is; previously a tuple was wrapped and then failed on `.parameters()`.
    if not isinstance(nets, (list, tuple)):
        nets = [nets]
    for net in nets:
        if net is None:
            continue
        for param in net.parameters():
            param.requires_grad = requires_grad
|
|
|
|
|
def zero_module(module):
    """Fill every parameter of ``module`` with zeros in place.

    Each parameter is detached before the in-place ``zero_()`` call, and the
    same ``module`` object is returned so the call can be used inline.
    """
    for weight in module.parameters():
        detached = weight.detach()
        detached.zero_()
    return module
|
|