def init_hypers(self, y_train, *, subsample_threshold=10000, n_dist_subsample=5000):
    """Initialise inducing inputs and kernel/noise hyperparameters.

    Inducing inputs ``zu`` are chosen with k-means (``kmeans2`` with
    ``minit='points'``) when there are fewer than ``subsample_threshold``
    training points; otherwise ``self.M`` training inputs are drawn at random
    without replacement. The log length-scales are set from the median
    pairwise Euclidean distance between training inputs, computed on a random
    subset of ``n_dist_subsample`` points when the data set is large.

    Args:
        y_train: Training targets. Not used by this routine; kept for
            interface compatibility.
        subsample_threshold (int): Data sets with at least this many points
            use random selection/subsampling instead of k-means and the full
            distance matrix. Defaults to 10000.
        n_dist_subsample (int): Number of points used to estimate the median
            distance when subsampling. Defaults to 5000.

    Returns:
        dict: ``'sf'`` -- array ``[log(0.5)]``;
        ``'ls'`` -- array of length ``self.Din`` filled with
        ``2 * log(median_distance + 1e-16)``;
        ``'zu'`` -- inducing inputs, one row per inducing point;
        ``'sn'`` -- ``log(0.01)``.
    """
    N = self.N
    M = self.M
    Din = self.Din
    # x_train is indexed as x_train[rows, :], so it must be 2-D (points x dims).
    x_train = self.x_train

    if N < subsample_threshold:
        zu, _ = kmeans2(x_train, M, minit='points')
        # cdist does not modify its inputs, so no copy is needed.
        x_sub = x_train
    else:
        # Too many points for k-means: take M distinct training inputs at random.
        zu = x_train[np.random.permutation(N)[:M], :]
        # Subsample to bound the quadratic cost of the distance matrix.
        x_sub = x_train[np.random.permutation(N)[:n_dist_subsample], :]

    x_dist = cdist(x_sub, x_sub, 'euclidean')
    # Index by the number of rows actually used (the subsample may be smaller
    # than N). The upper triangle here includes the diagonal, as in the
    # original code, so the zero self-distances are part of the median.
    triu_ind = np.triu_indices(x_sub.shape[0])
    d2imed = np.median(x_dist[triu_ind])
    # Same value for every input dimension; 1e-16 avoids log(0) when the
    # median distance is zero.
    ls = np.full((Din,), 2 * np.log(d2imed + 1e-16))

    return {
        'sf': np.log(np.array([0.5])),
        'ls': ls,
        'zu': zu,
        'sn': np.log(0.01),
    }
# Scraped web-page residue, not code: "评论列表" (Comment list), "文章目录" (Table of contents)