File size: 2,283 Bytes
d344462 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Memory-efficient MMD implementation in JAX."""
import torch
# The bandwidth parameter for the Gaussian RBF kernel. See the paper for more
# details.
_SIGMA = 10
# The following is used to make the metric more human readable. See the paper
# for more details.
_SCALE = 1000
def _mean_rbf_kernel(a, b, a_sqnorms, b_sqnorms, gamma):
    """Returns the mean Gaussian RBF kernel value over all (a_i, b_j) pairs.

    Args:
      a: Tensor of shape (n, d).
      b: Tensor of shape (m, d), same dtype as `a`.
      a_sqnorms: Tensor of shape (n,), squared L2 norm of each row of `a`.
      b_sqnorms: Tensor of shape (m,), squared L2 norm of each row of `b`.
      gamma: Kernel coefficient; k(a_i, b_j) = exp(-gamma * ||a_i - b_j||^2).

    Returns:
      A 0-dim tensor: the mean of the (n, m) kernel matrix.
    """
    # ||a_i - b_j||^2 = -2 <a_i, b_j> + ||a_i||^2 + ||b_j||^2, evaluated for the
    # whole (n, m) grid at once by broadcasting the norm vectors.
    sq_dists = (
        -2 * torch.matmul(a, b.T)
        + torch.unsqueeze(a_sqnorms, 1)
        + torch.unsqueeze(b_sqnorms, 0)
    )
    return torch.mean(torch.exp(-gamma * sq_dists))


def mmd(x, y, *, sigma=None, scale=None):
    """Computes a scaled, biased estimate of the squared MMD between x and y.

    This implements the minimum-variance/biased version of the estimator
    described in Eq.(5) of
    https://jmlr.csail.mit.edu/papers/volume13/gretton12a/gretton12a.pdf,
    with the Gaussian RBF kernel k(a, b) = exp(-||a - b||^2 / (2 * sigma^2)).
    As described in Lemma 6's proof in that paper, the unbiased estimate and
    the minimum-variance estimate for MMD are almost identical. No square root
    is taken, so the value is an estimate of MMD^2 (multiplied by `scale`).

    Each of the three kernel terms materializes intermediates whose size is
    the product of the two row counts involved, so memory grows
    quadratically with the number of embeddings.

    Args:
      x: The first set of embeddings, shape (n, embedding_dim). Any input
        accepted by `torch.as_tensor` (NumPy array, torch tensor, nested list).
      y: The second set of embeddings, shape (m, embedding_dim). `m` may
        differ from `n`.
      sigma: Bandwidth of the Gaussian RBF kernel. Defaults to `_SIGMA`.
      scale: Multiplier applied to the estimate to make it more human
        readable. Defaults to `_SCALE`.

    Returns:
      A 0-dim tensor equal to
      scale * (mean k(x_i, x_j) + mean k(y_i, y_j) - 2 * mean k(x_i, y_j)).
      It is NaN if x or y has no rows.

    Raises:
      ValueError: If x or y is not 2-D.
    """
    sigma = _SIGMA if sigma is None else sigma
    scale = _SCALE if scale is None else scale

    # as_tensor shares memory with NumPy input (like from_numpy) and also
    # accepts torch tensors and nested sequences.
    x = torch.as_tensor(x)
    y = torch.as_tensor(y)
    if x.dim() != 2 or y.dim() != 2:
        raise ValueError(
            "x and y must be 2-D (num_embeddings, embedding_dim); got shapes "
            f"{tuple(x.shape)} and {tuple(y.shape)}"
        )

    # matmul needs both operands in one dtype, so promote to a common type.
    # Integer/bool inputs are computed in the default floating dtype.
    dtype = torch.promote_types(x.dtype, y.dtype)
    if not (dtype.is_floating_point or dtype.is_complex):
        dtype = torch.get_default_dtype()
    x = x.to(dtype)
    y = y.to(dtype)

    # Row-wise squared norms in O(rows * dim), without building the
    # (rows, rows) Gram matrix that diag(x @ x.T) would materialize.
    x_sqnorms = torch.sum(x * x, dim=1)
    y_sqnorms = torch.sum(y * y, dim=1)

    gamma = 1 / (2 * sigma**2)
    k_xx = _mean_rbf_kernel(x, x, x_sqnorms, x_sqnorms, gamma)
    k_yy = _mean_rbf_kernel(y, y, y_sqnorms, y_sqnorms, gamma)
    k_xy = _mean_rbf_kernel(x, y, x_sqnorms, y_sqnorms, gamma)
    return scale * (k_xx + k_yy - 2 * k_xy)
|