def rbf_kernel(X0, h_min=1e-8):
    """Build a symbolic RBF (Gaussian) kernel over the rows of ``X0``.

    The bandwidth ``h`` is picked with a median heuristic. Take the median
    ``med`` of all n*n pairwise squared distances; this includes the n zero
    self-distances on the diagonal. Then set
    ``h = sqrt(0.5 * med / log(n + 1)) / 2`` and floor it at ``h_min``.

    Parameters
    ----------
    X0 : symbolic 2-D tensor
        Row ``i`` is point ``i``; ``n = X0.shape[0]``.
    h_min : float, optional
        Lower bound on the bandwidth. Without it, a median squared distance
        of 0 (e.g. a single point, or all points coincident) gives
        ``h == 0``, and the kernel ``exp(-0 / 0)`` evaluates to NaN.

    Returns
    -------
    Kxy : symbolic (n, n) tensor
        ``exp(-H / h**2 / 2)``, where ``H[i, j]`` is the squared Euclidean
        distance between rows ``i`` and ``j``.
    neighbors : symbolic length-n integer vector
        For each row of ``H``, the column index in second position after
        sorting by distance. This is intended as the nearest *other* point.
        It assumes n >= 2 and that each point's own (zero) distance sorts
        first, which may not hold when points are exactly duplicated.
    h : symbolic scalar
        The bandwidth used in ``Kxy``.
    """
    n = X0.shape[0]

    # Pairwise squared distances via ||x||^2 + ||y||^2 - 2 x.y.
    XY = T.dot(X0, X0.transpose())
    x2 = T.reshape(T.sum(T.square(X0), axis=1), (n, 1))
    X2e = T.repeat(x2, n, axis=1)
    H = X2e + X2e.transpose() - 2 * XY

    # Median of all n*n entries of H (sorted once, reused by both branches).
    V = T.sort(H.flatten())
    size = V.shape[0]
    mid = size // 2
    med = T.switch(T.eq(size % 2, 0),
                   # even count: mean of the two middle values
                   T.mean(V[mid - 1:mid + 1]),
                   # odd count: the middle value
                   V[mid])

    # Floating-point cancellation in H can make the median slightly
    # negative; clamp it so the sqrt below cannot produce NaN.
    med = T.maximum(med, 0.0)
    # The trailing "/ 2." halves the median-heuristic bandwidth.
    h = T.sqrt(0.5 * med / T.log(n.astype('float32') + 1.0)) / 2.
    # Keep h strictly positive so the kernel never divides 0 by 0.
    h = T.maximum(h, h_min)

    Kxy = T.exp(-H / h ** 2 / 2.0)
    # Column 0 of each row's argsort is expected to be the point itself.
    neighbors = T.argsort(H, axis=1)[:, 1]
    return Kxy, neighbors, h
评论列表
文章目录