# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MMD (maximum mean discrepancy) implementation in PyTorch."""

import torch

# The bandwidth parameter for the Gaussian RBF kernel. See the paper for more
# details.
_SIGMA = 10
# The following is used to make the metric more human readable. See the paper
# for more details.
_SCALE = 1000


def _mean_rbf_kernel(a, b, a_sqnorms, b_sqnorms, gamma):
  """Returns the mean of exp(-gamma * ||a_i - b_j||^2) over all row pairs.

  Args:
    a: Tensor of shape (n, d).
    b: Tensor of shape (m, d), same dtype as `a`.
    a_sqnorms: Squared L2 norms of the rows of `a`, shape (n,).
    b_sqnorms: Squared L2 norms of the rows of `b`, shape (m,).
    gamma: Kernel coefficient, 1 / (2 * sigma**2).

  Returns:
    A 0-dim tensor holding the mean Gaussian RBF kernel value.
  """
  # ||a_i - b_j||^2 = ||a_i||^2 + ||b_j||^2 - 2 <a_i, b_j>. This needs only an
  # (n, m) matrix rather than an (n, m, d) tensor of pairwise differences.
  sq_dists = (
      torch.unsqueeze(a_sqnorms, 1)
      + torch.unsqueeze(b_sqnorms, 0)
      - 2 * torch.matmul(a, b.T)
  )
  # Cancellation in the expansion above can leave tiny negative values (e.g.
  # on the diagonal when `a` is `b`); true squared distances are never
  # negative, and a negative value would push the kernel above 1.
  sq_dists = torch.clamp(sq_dists, min=0)
  return torch.mean(torch.exp(-gamma * sq_dists))


def mmd(x, y, *, sigma=_SIGMA, scale=_SCALE):
  """Computes a scaled MMD estimate between two embedding sets in PyTorch.

  This implements the minimum-variance/biased version of the estimator
  described in Eq.(5) of
  https://jmlr.csail.mit.edu/papers/volume13/gretton12a/gretton12a.pdf. As
  described in Lemma 6's proof in that paper, the unbiased estimate and the
  minimum-variance estimate for MMD are almost identical.

  A Gaussian RBF kernel k(a, b) = exp(-||a - b||^2 / (2 * sigma**2)) is used,
  and the returned value is scale * (mean k(x, x) + mean k(y, y)
  - 2 * mean k(x, y)).

  Args:
    x: The first set of embeddings of shape (n, embedding_dim). A numpy array,
      torch tensor, or anything accepted by `torch.as_tensor`.
    y: The second set of embeddings of shape (m, embedding_dim); m may differ
      from n. Same accepted types as `x`.
    sigma: Bandwidth of the Gaussian RBF kernel. Must be positive. Defaults to
      the module constant `_SIGMA`.
    scale: Multiplier applied to the result for readability. Defaults to the
      module constant `_SCALE`.

  Returns:
    A 0-dim torch tensor with the scaled (biased) squared-MMD estimate between
    the x and y embedding sets.

  Raises:
    ValueError: If `sigma` is not positive.
  """
  if sigma <= 0:
    raise ValueError(f"sigma must be positive, got {sigma}.")

  # as_tensor shares memory with numpy inputs (like from_numpy) but also
  # accepts existing torch tensors.
  x = torch.as_tensor(x)
  y = torch.as_tensor(y)
  # Promote to a common dtype so mixed float32/float64 inputs do not fail in
  # matmul; `.to` is a no-op when the dtype already matches.
  common_dtype = torch.promote_types(x.dtype, y.dtype)
  x = x.to(common_dtype)
  y = y.to(common_dtype)

  # Row-wise squared norms in O(n * d) memory, instead of materialising the
  # full Gram matrix only to read its diagonal.
  x_sqnorms = torch.sum(x * x, dim=1)
  y_sqnorms = torch.sum(y * y, dim=1)
  gamma = 1 / (2 * sigma**2)

  k_xx = _mean_rbf_kernel(x, x, x_sqnorms, x_sqnorms, gamma)
  k_xy = _mean_rbf_kernel(x, y, x_sqnorms, y_sqnorms, gamma)
  k_yy = _mean_rbf_kernel(y, y, y_sqnorms, y_sqnorms, gamma)
  return scale * (k_xx + k_yy - 2 * k_xy)