def init_hypers(self, x_train=None, key_suffix='',
                subsample_threshold=10000, subsample_size=5000):
    """Initialise kernel hyperparameters and inducing inputs.

    With no training data, every log-lengthscale is set to log(1 + 0.1*u)
    with u ~ U[0, 1), the log signal scale is log(1), and the M inducing
    inputs are spread evenly over [-1, 1] in every input dimension.

    With training data, the inducing inputs are the k-means centroids of
    ``x_train`` (or, for large data sets, M randomly chosen rows of it).
    Every log-lengthscale is set to the log of the median pairwise
    Euclidean distance between training inputs (computed on a random
    subset of rows for large data sets), and the log signal scale is
    log(0.5).

    Args:
        x_train (array-like or None, optional): Training inputs. Indexed as
            ``x_train[rows, :]``, so presumably shaped (N, Din) with at
            least ``self.N`` rows -- TODO confirm against callers. If None,
            data-independent defaults are used.
        key_suffix (str, optional): Appended to every key of the returned
            dict.
        subsample_threshold (int, optional): When ``self.N`` is at least
            this value, k-means and the full pairwise-distance matrix are
            skipped in favour of random row subsets.
        subsample_size (int, optional): Number of random rows used for the
            median-distance heuristic when subsampling.

    Returns:
        dict: ``{'sf' + key_suffix: array of shape (1,),
        'ls' + key_suffix: array of shape (Din,),
        'zu' + key_suffix: inducing inputs}``. ``sf`` and ``ls`` are stored
        on the log scale.

    Note:
        Draws from NumPy's global random state (directly and through
        k-means initialisation), so results are only reproducible if the
        caller seeds it.
    """
    N = self.N
    M = self.M
    Din = self.Din
    subsample = N >= subsample_threshold

    if x_train is None:
        ls = np.log(np.ones((Din, )) + 0.1 * np.random.rand(Din, ))
        sf = np.log(np.array([1]))
        zu = np.tile(np.linspace(-1, 1, M).reshape((M, 1)), (1, Din))
    else:
        # Inducing inputs. For large data sets, k-means is skipped and
        # M randomly chosen training points are used instead.
        if subsample:
            randind = np.random.permutation(N)
            zu = x_train[randind[:M], :]
        else:
            zu, _ = kmeans2(x_train, M, minit='points')

        # Median heuristic for the lengthscale. The pairwise distance
        # matrix is quadratic in the number of rows, so a random subset
        # is used for large data sets.
        if subsample:
            randind = np.random.permutation(N)
            X1 = x_train[randind[:subsample_size], :]
        else:
            X1 = np.asarray(x_train)
        x_dist = cdist(X1, X1, 'euclidean')
        # Size the indices by X1, not N: when subsampling, X1 has fewer
        # rows than the training set. The default k=0 keeps the (zero)
        # diagonal in the median, as before.
        triu_ind = np.triu_indices(X1.shape[0])
        d2imed = np.median(x_dist[triu_ind])
        # The 1e-16 keeps the log finite when all points coincide.
        ls = np.full((Din, ), np.log(d2imed + 1e-16))
        sf = np.log(np.array([0.5]))

    return {
        'sf' + key_suffix: sf,
        'ls' + key_suffix: ls,
        'zu' + key_suffix: zu,
    }
评论列表
文章目录