UOT / otfuncs.py
4kasha
fix
f31ab4f
import torch
import torch.nn.functional as F
from ot.backend import get_backend
device = "cuda" if torch.cuda.is_available() else "cpu"
def compute_distance_matrix_cosine(
s1_word_embeddigs, s2_word_embeddigs, distortion_ratio
):
sim_matrix = (
torch.matmul(F.normalize(s1_word_embeddigs), F.normalize(s2_word_embeddigs).t())
+ 1.0
) / 2 # Range 0-1
C = apply_distortion(sim_matrix, distortion_ratio)
C = min_max_scaling(C) # Range 0-1
C = 1.0 - C # Convert to distance
return C, sim_matrix
def compute_distance_matrix_l2(s1_word_embeddigs, s2_word_embeddigs, distortion_ratio):
C = torch.cdist(s1_word_embeddigs, s2_word_embeddigs, p=2)
C = min_max_scaling(C) # Range 0-1
C = 1.0 - C # Convert to similarity
C = apply_distortion(C, distortion_ratio)
C = min_max_scaling(C) # Range 0-1
C = 1.0 - C # Convert to distance
return C
def apply_distortion(sim_matrix, ratio):
shape = sim_matrix.shape
if (shape[0] < 2 or shape[1] < 2) or ratio == 0.0:
return sim_matrix
pos_x = torch.tensor(
[[y / float(shape[1] - 1) for y in range(shape[1])] for x in range(shape[0])],
device=device,
)
pos_y = torch.tensor(
[[x / float(shape[0] - 1) for x in range(shape[0])] for y in range(shape[1])],
device=device,
)
relative_distance = (pos_x - pos_y.T) ** 2
distortion_mask = 1.0 - relative_distance * ratio
sim_matrix = torch.mul(sim_matrix, distortion_mask)
return sim_matrix
def compute_weights_norm(s1_word_embeddigs, s2_word_embeddigs):
s1_weights = torch.norm(s1_word_embeddigs, dim=1)
s2_weights = torch.norm(s2_word_embeddigs, dim=1)
return s1_weights, s2_weights
def compute_weights_uniform(s1_word_embeddigs, s2_word_embeddigs):
s1_weights = torch.ones(
s1_word_embeddigs.shape[0], dtype=torch.float64, device=device
)
s2_weights = torch.ones(
s2_word_embeddigs.shape[0], dtype=torch.float64, device=device
)
# # Uniform weights to make L2 norm=1
# s1_weights /= torch.linalg.norm(s1_weights)
# s2_weights /= torch.linalg.norm(s2_weights)
return s1_weights, s2_weights
def min_max_scaling(C):
eps = 1e-10
# Min-max scaling for stabilization
nx = get_backend(C)
C_min = nx.min(C)
C_max = nx.max(C)
C = (C - C_min + eps) / (C_max - C_min + eps)
return C