def _mix_rbf_kernel(X, Y, sigma_list):
    """Evaluate a mixture of Gaussian (RBF) kernels on the rows of X and Y.

    X and Y are stacked row-wise into ``Z`` and, for every pair of rows
    ``(z_i, z_j)``, the value ``sum_k exp(-||z_i - z_j||^2 / (2 * sigma_k^2))``
    is computed over all bandwidths ``sigma_k`` in ``sigma_list``.

    Args:
        X: 2-D tensor (``torch.mm`` is applied to the stacked rows).
        Y: 2-D tensor with the same number of rows as ``X``; it must also
            share the second dimension so it can be concatenated with ``X``.
        sigma_list: sized iterable of kernel bandwidths. Each bandwidth must
            be non-zero. An empty list yields all-zero kernel blocks.

    Returns:
        A 4-tuple ``(K_XX, K_XY, K_YY, d)`` where the first three entries are
        the ``m x m`` blocks of the summed kernel matrix (X vs X, X vs Y,
        Y vs Y, with ``m = X.size(0)``) and ``d`` is ``len(sigma_list)``.

    Raises:
        AssertionError: if ``X`` and ``Y`` have a different number of rows.
    """
    assert X.size(0) == Y.size(0), "X and Y must have the same number of rows"
    m = X.size(0)

    # Stack both samples so one Gram matrix yields all three kernel blocks.
    Z = torch.cat((X, Y), 0)
    ZZT = torch.mm(Z, Z.t())

    # Pairwise squared distances via ||a - b||^2 = ||a||^2 - 2<a, b> + ||b||^2.
    # The squared norms are the diagonal of the Gram matrix.
    diag_ZZT = torch.diag(ZZT).unsqueeze(1)
    Z_norm_sqr = diag_ZZT.expand_as(ZZT)
    # NOTE(review): rounding may leave tiny negative off-diagonal entries for
    # near-identical rows; they are intentionally left unclamped here.
    exponent = Z_norm_sqr - 2 * ZZT + Z_norm_sqr.t()

    # Start from a tensor (not the float 0.0) so that an empty sigma_list still
    # produces sliceable zero blocks instead of failing on K[:m, :m].
    K = torch.zeros_like(exponent)
    for sigma in sigma_list:
        gamma = 1.0 / (2 * sigma**2)
        # Out-of-place add: keeps dtype promotion and autograd behaviour the
        # same as accumulating onto a plain scalar.
        K = K + torch.exp(-gamma * exponent)

    return K[:m, :m], K[:m, m:], K[m:, m:], len(sigma_list)
评论列表
文章目录