"""Memory-efficient MMD implementation in JAX.""" |
|
|
|
import torch |

# Bandwidth parameter of the Gaussian RBF kernel used below.
_SIGMA = 10
# Constant factor applied to the computed MMD value before it is returned.
_SCALE = 1000


def mmd(x, y):
    """Memory-efficient MMD implementation in PyTorch.

    This implements the minimum-variance/biased version of the estimator
    described in Eq. (5) of
    https://jmlr.csail.mit.edu/papers/volume13/gretton12a/gretton12a.pdf.
    As described in Lemma 6's proof in that paper, the unbiased estimate and
    the minimum-variance estimate for MMD are almost identical.

    Args:
      x: The first set of embeddings, a NumPy array of shape (n, embedding_dim).
      y: The second set of embeddings, a NumPy array of shape (n, embedding_dim).

    Returns:
      The MMD distance between the x and y embedding sets, as a scalar tensor.
    """
    x = torch.from_numpy(x)
    y = torch.from_numpy(y)

    # Row-wise squared norms, used to expand ||a - b||^2 as
    # ||a||^2 - 2 * a.b + ||b||^2, so pairwise squared distances are built
    # from matrix products without materializing difference tensors.
    x_sqnorms = torch.diag(torch.matmul(x, x.T))
    y_sqnorms = torch.diag(torch.matmul(y, y.T))

    gamma = 1 / (2 * _SIGMA**2)
    # Mean Gaussian RBF kernel value over all pairs within x, within y, and
    # between x and y.
    k_xx = torch.mean(
        torch.exp(-gamma * (-2 * torch.matmul(x, x.T) + torch.unsqueeze(x_sqnorms, 1) + torch.unsqueeze(x_sqnorms, 0)))
    )
    k_xy = torch.mean(
        torch.exp(-gamma * (-2 * torch.matmul(x, y.T) + torch.unsqueeze(x_sqnorms, 1) + torch.unsqueeze(y_sqnorms, 0)))
    )
    k_yy = torch.mean(
        torch.exp(-gamma * (-2 * torch.matmul(y, y.T) + torch.unsqueeze(y_sqnorms, 1) + torch.unsqueeze(y_sqnorms, 0)))
    )

    return _SCALE * (k_xx + k_yy - 2 * k_xy)
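

if __name__ == "__main__":
    # Illustrative usage sketch: compares two random embedding sets. Real
    # inputs would be NumPy float arrays of shape (n, embedding_dim); the
    # names emb_a and emb_b are placeholders chosen for this example.
    import numpy as np

    rng = np.random.default_rng(0)
    emb_a = rng.normal(size=(64, 32)).astype(np.float32)
    emb_b = rng.normal(size=(64, 32)).astype(np.float32)

    # Identical sets give a (numerically) zero distance; distinct sets give a
    # positive value, scaled by _SCALE.
    print("mmd(emb_a, emb_a) =", float(mmd(emb_a, emb_a)))
    print("mmd(emb_a, emb_b) =", float(mmd(emb_a, emb_b)))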