Anonymous-123's picture
Add application file
ec0fdfd
raw
history blame
396 Bytes
from torch.nn import functional as F
def d_clip_loss(x, y, use_cosine=False):
    """Return a distance between the L2-normalized tensors *x* and *y*.

    Both inputs are first normalized to unit length along their last
    dimension (presumably embedding vectors, going by the function name --
    confirm against callers).

    Args:
        x: Tensor whose last dimension is the feature dimension.
        y: Tensor whose last dimension is the feature dimension.
        use_cosine: If true, return ``1 - x @ y.t()`` with singleton
            dimensions squeezed away, i.e. one minus the cosine similarity
            of every row of ``x`` with every row of ``y``. If false
            (default), return the element-wise (broadcast) value
            ``2 * arcsin(||x - y|| / 2) ** 2``; for unit vectors this is
            half the squared angle between them.

    Returns:
        Tensor of distances; its shape depends on the selected branch.
    """
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    if use_cosine:
        # Matrix product of unit vectors gives pairwise cosine similarities.
        return 1 - (x @ y.t()).squeeze()
    # For unit vectors ||x - y|| / 2 == sin(angle / 2), which is at most 1.
    # Rounding can push it marginally above 1 for near-antipodal inputs and
    # arcsin would then return NaN, so cap it at 1 first.
    half_chord = (x - y).norm(dim=-1).div(2).clamp(max=1)
    return half_chord.arcsin().pow(2).mul(2)
def range_loss(input):
    """Return the mean squared amount by which *input* leaves ``[-1, 1]``.

    Values inside ``[-1, 1]`` contribute zero; values outside contribute the
    square of their distance to the nearest bound.

    Args:
        input: Tensor with at least four dimensions; the mean is taken over
            dimensions 1, 2 and 3 (presumably a batch of images laid out as
            (N, C, H, W) -- confirm against callers).

    Returns:
        Tensor with dimensions 1-3 reduced away (one value per leading index
        for a 4-D input).
    """
    # input - clamp(input) is zero inside the range and the signed overshoot
    # outside of it.
    overshoot = input - input.clamp(-1, 1)
    return overshoot.pow(2).mean([1, 2, 3])