from torch.nn import functional as F


def d_clip_loss(x, y, use_cosine=False):
    """Return a distance between L2-normalised ``x`` and ``y``.

    Both inputs are normalised along their last dimension first.

    Args:
        x: Tensor of embeddings.
        y: Tensor of embeddings.
        use_cosine: Selects which distance is computed.

            * ``False`` (default): squared great-circle distance taken
              row-wise over the last dimension. For unit vectors
              ``2 * arcsin(||x - y|| / 2)`` is the angle between them, so
              the value returned is ``angle ** 2 / 2``. ``x`` and ``y``
              must broadcast against each other.
            * ``True``: ``1 - x @ y.t()`` with singleton dimensions
              squeezed out. Note that this is the *pairwise* matrix between
              every row of ``x`` and every row of ``y`` (``y.t()`` requires
              ``y`` to have at most two dimensions), not a row-wise
              distance like the default branch.

    Returns:
        Tensor of distances; its shape depends on the branch as described
        above.
    """
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)

    if use_cosine:
        return 1 - (x @ y.t()).squeeze()

    half_chord = (x - y).norm(dim=-1).div(2)
    # Mathematically half_chord <= 1 for unit vectors, but rounding can push
    # it a hair above 1 for (near-)antipodal inputs, where arcsin would
    # return NaN. Clamping keeps the result finite at the maximum distance.
    half_chord = half_chord.clamp(max=1)
    return half_chord.arcsin().pow(2).mul(2)


def range_loss(input):
    """Penalise values of ``input`` that fall outside ``[-1, 1]``.

    The penalty for each element is the squared amount by which it exceeds
    the interval (zero for elements already inside it). Penalties are
    averaged over dimensions 1, 2 and 3, so ``input`` needs at least four
    dimensions.

    Args:
        input: Tensor with at least four dimensions.

    Returns:
        Tensor with dimensions 1-3 reduced away (one value per leading
        index for a 4-D input).
    """
    overshoot = input - input.clamp(-1, 1)
    return overshoot.pow(2).mean([1, 2, 3])