import torch | |
def mean_activations(tensor):
    """Average an activation tensor over dimension 1 and drop dimension 0 if it has size 1.

    The input is first detached from the autograd graph and copied to the CPU,
    so the result never requires grad.

    Args:
        tensor: a ``torch.Tensor`` with at least two dimensions
            (presumably ``(batch, channels, ...)`` activation maps -- confirm
            with callers).

    Returns:
        The mean over ``dim=1``. Dimension 0 of that result is removed only
        when its size is 1 (e.g. a single-item batch); otherwise it is kept.
    """
    cpu_copy = tensor.detach().cpu()
    averaged = cpu_copy.mean(dim=1)
    # squeeze is a no-op unless the leading (presumably batch) dim is 1
    return averaged.squeeze(0)
def load_weights(model, weights, map_location='cpu'):
    """Load the checkpoint entries that fit ``model``, ignoring everything else.

    An entry from the checkpoint is copied into the model only when:
      * its key exists in ``model.state_dict()``, and
      * if both values are tensors, their shapes match (e.g. a classifier
        head with a different number of classes is skipped instead of making
        ``load_state_dict`` fail with a size-mismatch error).
    Model entries that are not overwritten keep their current values. The
    model is updated in place; nothing is returned.

    Args:
        model: a ``torch.nn.Module`` to update.
        weights: a path or file-like object accepted by ``torch.load``, or an
            already-loaded state-dict mapping.
        map_location: forwarded to ``torch.load`` when ``weights`` is not a
            mapping; defaults to ``'cpu'`` as before.
    """
    from collections.abc import Mapping

    if isinstance(weights, Mapping):
        pretrained_dict = weights
    else:
        # NOTE(review): torch.load unpickles data; only use trusted checkpoints.
        pretrained_dict = torch.load(weights, map_location=map_location)

    model_dict = model.state_dict()

    def _fits(key, value):
        # Keep only keys the model knows about, with compatible tensor shapes.
        if key not in model_dict:
            return False
        current = model_dict[key]
        if torch.is_tensor(value) and torch.is_tensor(current):
            return value.shape == current.shape
        return True

    compatible = {k: v for k, v in pretrained_dict.items() if _fits(k, v)}
    model_dict.update(compatible)
    model.load_state_dict(model_dict)