import torch
import torch.nn.functional as F


def norm_mse_loss(
    x0: torch.Tensor, x1: torch.Tensor, *, dim: int = -1
) -> torch.Tensor:
    """Return the mean squared distance between L2-normalized vectors.

    Both inputs are L2-normalized along ``dim``. The squared Euclidean
    distance between each pair of unit vectors is then averaged. For unit
    vectors a and b, ``||a - b||^2 == 2 - 2 * <a, b>``, so the result lies
    in [0, 4]. It is 0 when the vectors point the same way and 4 when they
    point in opposite directions.

    Args:
        x0: First batch of vectors.
        x1: Second batch of vectors, broadcast-compatible with ``x0``.
        dim: Axis holding the vector components. Normalization and the
            dot product both use this axis so that they stay consistent.
            The default (-1) matches the original behavior for 2-D inputs.

    Returns:
        A scalar tensor with the mean of ``2 - 2 * cos_sim`` over all
        remaining (non-``dim``) positions.
    """
    # Normalize over the same axis the dot product reduces over. Relying on
    # F.normalize's default (dim=1) is wrong for anything but 2-D input.
    z0 = F.normalize(x0, dim=dim)
    z1 = F.normalize(x1, dim=dim)
    cos_sim = (z0 * z1).sum(dim=dim)
    return 2 - 2 * cos_sim.mean()