|
import torch |
|
import torch.nn.functional as F |
|
from ot.backend import get_backend |
|
|
|
# Default device string used by helpers in this module: prefer the GPU when
# PyTorch reports that CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
|
|
def compute_distance_matrix_cosine(
    s1_word_embeddigs, s2_word_embeddigs, distortion_ratio
):
    """Build a cosine-based cost matrix between two sets of embeddings.

    Both inputs are normalized with ``F.normalize`` (its defaults) and
    multiplied to obtain pairwise cosine similarities, which are then mapped
    from [-1, 1] to [0, 1] via ``(cos + 1) / 2``.

    Args:
        s1_word_embeddigs: Embeddings of the first sentence, one row per
            token (assumed 2-D ``(n_tokens, dim)`` -- confirm with callers).
        s2_word_embeddigs: Embeddings of the second sentence, same layout.
        distortion_ratio: Strength forwarded to ``apply_distortion``.

    Returns:
        A tuple ``(C, sim_matrix)``: ``C`` is one minus the distorted,
        min-max scaled similarity; ``sim_matrix`` is the similarity in
        [0, 1] before distortion and scaling.
    """
    s1_unit = F.normalize(s1_word_embeddigs)
    s2_unit = F.normalize(s2_word_embeddigs)

    # Shift cosine similarity from [-1, 1] into [0, 1].
    cosine = torch.matmul(s1_unit, s2_unit.t())
    sim_matrix = (cosine + 1.0) / 2

    # Distort and rescale the similarity, then flip it into a cost.
    scaled = min_max_scaling(apply_distortion(sim_matrix, distortion_ratio))
    return 1.0 - scaled, sim_matrix
|
|
|
|
|
def compute_distance_matrix_l2(s1_word_embeddigs, s2_word_embeddigs, distortion_ratio):
    """Build a Euclidean-distance-based cost matrix between two embedding sets.

    The pairwise L2 distances are min-max scaled and inverted into a
    similarity so that ``apply_distortion`` (which operates on similarities)
    can be applied; the result is then scaled and inverted back into a cost.

    Args:
        s1_word_embeddigs: Embeddings of the first sentence, one row per token.
        s2_word_embeddigs: Embeddings of the second sentence, one row per token.
        distortion_ratio: Strength forwarded to ``apply_distortion``.

    Returns:
        The cost matrix ``C`` with one row per row of ``s1_word_embeddigs``
        and one column per row of ``s2_word_embeddigs``.
    """
    distances = torch.cdist(s1_word_embeddigs, s2_word_embeddigs, p=2)

    # Distance -> similarity, so the distortion mask can be applied to it.
    similarity = 1.0 - min_max_scaling(distances)
    distorted = apply_distortion(similarity, distortion_ratio)

    # Similarity -> cost again.
    return 1.0 - min_max_scaling(distorted)
|
|
|
|
|
def apply_distortion(sim_matrix, ratio):
    """Down-weight similarities between entries at distant relative positions.

    Entry ``(i, j)`` is multiplied by
    ``1 - ratio * (j / (n_cols - 1) - i / (n_rows - 1)) ** 2``, so pairs on
    the "diagonal" of relative positions are kept and off-diagonal pairs are
    attenuated.

    Args:
        sim_matrix: 2-D tensor of similarities, shape ``(n_rows, n_cols)``.
        ratio: Distortion strength; ``0.0`` disables the distortion.

    Returns:
        ``sim_matrix`` itself (unchanged) when either dimension is smaller
        than 2 or ``ratio == 0.0``; otherwise a new tensor holding the
        element-wise product of ``sim_matrix`` and the distortion mask.
    """
    n_rows, n_cols = sim_matrix.shape[0], sim_matrix.shape[1]
    if n_rows < 2 or n_cols < 2 or ratio == 0.0:
        return sim_matrix

    # Build the mask on the input's own device so the product never mixes
    # devices, and in at least float32 so the result dtype follows normal
    # type promotion (float32 stays float32, float64 stays float64).
    mask_dtype = torch.promote_types(sim_matrix.dtype, torch.float32)
    target_device = sim_matrix.device

    # Relative position of every row / column, each spanning [0, 1].
    row_pos = torch.arange(n_rows, dtype=mask_dtype, device=target_device) / (
        n_rows - 1
    )
    col_pos = torch.arange(n_cols, dtype=mask_dtype, device=target_device) / (
        n_cols - 1
    )

    # Broadcasting (1, n_cols) against (n_rows, 1) yields the full grid.
    relative_distance = (col_pos.unsqueeze(0) - row_pos.unsqueeze(1)) ** 2
    distortion_mask = 1.0 - relative_distance * ratio

    return torch.mul(sim_matrix, distortion_mask)
|
|
|
|
|
def compute_weights_norm(s1_word_embeddigs, s2_word_embeddigs):
    """Weight every row of each embedding matrix by its norm.

    Args:
        s1_word_embeddigs: Tensor of embeddings for the first sentence.
        s2_word_embeddigs: Tensor of embeddings for the second sentence.

    Returns:
        A tuple ``(s1_weights, s2_weights)`` where each element is
        ``torch.norm(embeddings, dim=1)`` of the corresponding input.
    """
    # Same reduction for both inputs: the norm taken along dim 1.
    s1_weights, s2_weights = (
        torch.norm(embeddings, dim=1)
        for embeddings in (s1_word_embeddigs, s2_word_embeddigs)
    )
    return s1_weights, s2_weights
|
|
|
|
|
def compute_weights_uniform(s1_word_embeddigs, s2_word_embeddigs):
    """Give every row of each embedding matrix the same weight of 1.0.

    The weights are allocated on the device of the embeddings they belong
    to (rather than on a module-wide default), so they can be combined with
    quantities derived from those embeddings without a device mismatch.

    Args:
        s1_word_embeddigs: Tensor of embeddings for the first sentence; only
            its first dimension and device are used.
        s2_word_embeddigs: Tensor of embeddings for the second sentence; only
            its first dimension and device are used.

    Returns:
        A tuple ``(s1_weights, s2_weights)`` of 1-D ``float64`` tensors of
        ones, with one entry per row of the corresponding input.
    """
    s1_weights = torch.ones(
        s1_word_embeddigs.shape[0],
        dtype=torch.float64,
        device=s1_word_embeddigs.device,
    )
    s2_weights = torch.ones(
        s2_word_embeddigs.shape[0],
        dtype=torch.float64,
        device=s2_word_embeddigs.device,
    )
    return s1_weights, s2_weights
|
|
|
|
|
def min_max_scaling(C):
    """Rescale ``C`` by its global minimum and maximum.

    Computes ``(C - min + eps) / (max - min + eps)`` with ``eps = 1e-10``.
    Because ``eps`` is added to both numerator and denominator, a constant
    input maps to all ones instead of dividing by zero, and the smallest
    entry maps to a tiny positive value rather than exactly zero.

    Args:
        C: Array or tensor accepted by ``ot.backend.get_backend``.

    Returns:
        The rescaled array, same shape as ``C``.
    """
    backend = get_backend(C)
    lowest = backend.min(C)
    spread = backend.max(C) - lowest

    # Small constant guarding against a zero range.
    eps = 1e-10
    return (C - lowest + eps) / (spread + eps)
|
|