def rbf_kernel(X, h=None):
    """Build the RBF kernel matrix and its summed gradient for the rows of X.

    With ``k(x, y) = exp(-||x - y||^2 / (2 h^2))`` and ``x_1 .. x_n`` the rows
    of ``X``, this returns:

    * ``kxy[i, j] = k(x_i, x_j)``
    * ``dxkxy[i] = sum_j grad_{x_j} k(x_j, x_i)
                 = (x_i * sum_j kxy[i, j] - sum_j kxy[i, j] * x_j) / h^2``

    The second quantity is the kernel-gradient ("repulsive") term of the
    Stein variational gradient descent update.

    NOTE(review): ``T`` is assumed to be ``theano.tensor`` imported elsewhere
    in this file; everything here builds a symbolic graph.

    Args:
        X: symbolic matrix whose rows are the points, shape ``(n, d)``.
        h: optional fixed bandwidth (Python float or symbolic scalar).
            When ``None`` (the default, and the original behaviour) the
            median heuristic ``h = sqrt(0.5 * median(H) / log(n + 1))`` is
            used. Here ``H`` is the ``(n, n)`` matrix of pairwise squared
            distances, and its median is taken over all ``n * n`` entries.

    Returns:
        ``(kxy, dxkxy)``: the ``(n, n)`` kernel matrix and the ``(n, d)``
        summed kernel gradient.
    """
    # Pairwise squared distances via ||x||^2 + ||y||^2 - 2 x.y. The (n, 1)
    # column of squared norms broadcasts against its (1, n) transpose, so no
    # explicit n x n repeat is needed.
    XY = T.dot(X, X.T)
    x2 = T.sum(X ** 2, axis=1).dimshuffle(0, 'x')
    # Round-off in the expansion can leave tiny negative entries. Clamping
    # keeps kxy <= 1 and prevents a negative median, which would make sqrt
    # return NaN.
    H = T.maximum(x2 + x2.T - 2. * XY, 0.)

    if h is None:
        # Median of all n*n entries (diagonal included). Sort once and reuse
        # the result for both the even-length and odd-length cases.
        V = T.sort(H.flatten())
        m = V.shape[0]
        half = m // 2
        med = T.switch(T.eq(m % 2, 0),
                       T.mean(V[half - 1:half + 1]),
                       V[half])
        h = T.sqrt(.5 * med / T.log(H.shape[0].astype('float32') + 1.))
        # A zero median (a single point, or most points coinciding) would
        # give h == 0, so kxy = exp(-0/0) = NaN. Flooring the bandwidth
        # makes those cases finite. The floor is built in h's own dtype so
        # it does not upcast the graph.
        h = T.maximum(h, T.constant(1e-8, dtype=h.dtype))

    h2 = h ** 2
    kxy = T.exp(-H / h2 / 2.0)
    # Row sums of the kernel, kept as an (n, 1) column for broadcasting.
    sumkxy = T.sum(kxy, axis=1).dimshuffle(0, 'x')
    dxkxy = (X * sumkxy - T.dot(kxy, X)) / h2
    return kxy, dxkxy
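# Web-page residue captured when this snippet was copied, kept as comments so
# the file parses: "评论列表" ("Comments") and "文章目录" ("Table of contents").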