import numpy as np
import torch
import ot

from otfuncs import (
    compute_distance_matrix_cosine,
    compute_distance_matrix_l2,
    compute_weights_norm,
    compute_weights_uniform,
    min_max_scaling
)


class Aligner:
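    """Compute word-alignment matrices between two sequences of word embeddings
    by solving an optimal transport (OT) problem with the POT library.

    Supported ``ot_type`` values, as handled below: 'ot' (balanced OT), 'pot'
    (partial OT), 'uot' and 'uot-mm' (unbalanced OT), and 'none' (no transport;
    the similarity 1 - C is used directly).
    """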

    def __init__(self, ot_type, sinkhorn, dist_type, weight_type, distortion, thresh, tau, **kwargs):
        self.ot_type = ot_type
        self.sinkhorn = sinkhorn
        self.dist_type = dist_type
        self.weight_type = weight_type
        self.distortion = distortion
        self.thresh = thresh
        self.tau = tau
        # Solver hyperparameters: entropic regularization strength,
        # convergence threshold, and iteration cap.
        self.epsilon = 0.1
        self.stopThr = 1e-6
        self.numItermax = 1000
        # Divergence used by the 'uot-mm' solver (e.g. 'kl' or 'l2').
        self.div_type = kwargs['div_type']

        # Distance function for the cost matrix: cosine or L2.
        self.dist_func = compute_distance_matrix_cosine if dist_type == 'cos' else compute_distance_matrix_l2
        # Marginal weights: uniform or derived from embedding norms.
        if weight_type == 'uniform':
            self.weight_func = compute_weights_uniform
        else:
            self.weight_func = compute_weights_norm

    def compute_alignment_matrixes(self, s1_word_embeddings, s2_word_embeddings):
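        """Return the alignment matrix P as a NumPy array, the cost matrix,
        the transport cost reported by the solver ('NotImplemented' if the
        solver's log does not include it), and the similarity matrix.
        """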
        P, Cost, log, similarity_matrix = self.compute_optimal_transport(s1_word_embeddings, s2_word_embeddings)
        if torch.is_tensor(P):
            P = P.to('cpu').numpy()
        loss = log.get('cost', 'NotImplemented')

        return P, Cost, loss, similarity_matrix

    def compute_optimal_transport(self, s1_word_embeddings, s2_word_embeddings):
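        """Solve the OT problem selected by ``self.ot_type`` and return the
        transport plan P, the cost matrix C, the solver log, and the
        similarity matrix.
        """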
        s1_word_embeddings = s1_word_embeddings.to(torch.float64)
        s2_word_embeddings = s2_word_embeddings.to(torch.float64)

        # Pairwise cost matrix C (and raw similarities) plus marginal weights.
        C, similarity_matrix = self.dist_func(s1_word_embeddings, s2_word_embeddings, self.distortion)
        s1_weights, s2_weights = self.weight_func(s1_word_embeddings, s2_word_embeddings)

        if self.ot_type == 'ot':
            # Balanced OT: marginals are normalized to sum to 1.
            s1_weights = s1_weights / s1_weights.sum()
            s2_weights = s2_weights / s2_weights.sum()
            s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)

            if self.sinkhorn:
                # Entropy-regularized OT, solved in log space for numerical stability.
                P, log = ot.bregman.sinkhorn_log(
                    s1_weights, s2_weights, C,
                    reg=self.epsilon, stopThr=self.stopThr,
                    numItermax=self.numItermax, log=True
                )
            else:
                # Exact OT (earth mover's distance).
                P, log = ot.emd(s1_weights, s2_weights, C, log=True)

            P = min_max_scaling(P)

        elif self.ot_type == 'pot':
            # Partial OT: only a fraction tau of the smaller total mass is transported.
            s1_weights, s2_weights, C = self.convert_to_numpy(s1_weights, s2_weights, C)
            m = np.min((np.sum(s1_weights), np.sum(s2_weights))) * self.tau

            if self.sinkhorn:
                P, log = ot.partial.entropic_partial_wasserstein(
                    s1_weights, s2_weights, C,
                    reg=self.epsilon,
                    m=m, stopThr=self.stopThr, numItermax=self.numItermax, log=True
                )
            else:
                P, log = ot.partial.partial_wasserstein(s1_weights, s2_weights, C, m=m, log=True)

            P = min_max_scaling(P)

        elif 'uot' in self.ot_type:
            # Unbalanced OT: marginal constraints are relaxed with strength tau.
            tau = self.tau

            if self.ot_type == 'uot':
                P, log = ot.unbalanced.sinkhorn_stabilized_unbalanced(
                    s1_weights, s2_weights, C, reg=self.epsilon, reg_m=tau,
                    stopThr=self.stopThr, numItermax=self.numItermax, log=True
                )
            elif self.ot_type == 'uot-mm':
                P, log = ot.unbalanced.mm_unbalanced(
                    s1_weights, s2_weights, C, reg_m=tau, div=self.div_type,
                    stopThr=self.stopThr, numItermax=self.numItermax, log=True
                )

            P = min_max_scaling(P)

        elif self.ot_type == 'none':
            # No transport: use the similarity (1 - cost) directly.
            P = 1 - C
            log = {}  # no solver log is produced in this case

        return P, C, log, similarity_matrix

    def convert_to_numpy(self, s1_weights, s2_weights, C):
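        """Move the marginal weights and the cost matrix to CPU NumPy arrays if they are torch tensors."""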
        if torch.is_tensor(s1_weights):
            s1_weights = s1_weights.to('cpu').numpy()
            s2_weights = s2_weights.to('cpu').numpy()
        if torch.is_tensor(C):
            C = C.to('cpu').numpy()

        return s1_weights, s2_weights, C
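

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module). The hyperparameter
    # values below are illustrative placeholders, and the sketch assumes the
    # functions imported from `otfuncs` accept (n, d) float tensors as they are
    # used in the methods above.
    aligner = Aligner(
        ot_type='ot', sinkhorn=True, dist_type='cos', weight_type='uniform',
        distortion=0.0, thresh=0.5, tau=0.3, div_type='kl'
    )
    s1 = torch.rand(5, 32)   # 5 tokens, 32-dim embeddings
    s2 = torch.rand(7, 32)   # 7 tokens, 32-dim embeddings
    P, C, loss, sim = aligner.compute_alignment_matrixes(s1, s2)
    print(P.shape, C.shape)  # expected: (5, 7) alignment and cost matrices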