def _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
    """Return an MMD^2 estimate, its variance estimate, and their ratio.

    Delegates the estimation to ``_mmd2_and_variance``. It then divides the
    MMD^2 value by the square root of the variance estimate. Before the
    square root, the variance is clamped from below at ``min_var_est``.

    Args:
        K_XX, K_XY, K_YY: Kernel matrices passed straight through to
            ``_mmd2_and_variance``. They are presumably K(X, X), K(X, Y) and
            K(Y, Y) as torch tensors (TODO: confirm against callers).
        const_diagonal: Forwarded unchanged to ``_mmd2_and_variance``.
        biased: Forwarded unchanged to ``_mmd2_and_variance``.

    Returns:
        A tuple ``(loss, mmd2, var_est)`` where
        ``loss = mmd2 / sqrt(clamp(var_est, min=min_var_est))``.
        The returned ``var_est`` is the unclamped value from
        ``_mmd2_and_variance``.
    """
    mmd2, var_est = _mmd2_and_variance(
        K_XX, K_XY, K_YY, const_diagonal=const_diagonal, biased=biased
    )
    # Clamp the variance from below before the sqrt. With a positive floor,
    # this avoids dividing by zero and taking the sqrt of a negative estimate.
    # NOTE(review): assumes min_var_est is a small positive module-level
    # constant defined elsewhere in this file -- confirm.
    loss = mmd2 / torch.sqrt(torch.clamp(var_est, min=min_var_est))
    return loss, mmd2, var_est
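# 评论列表 ("comment list") -- scraped web-page residue, not code
# 文章目录 ("table of contents") -- scraped web-page residue, not code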