vfe_models.py source code


Project: geepee    Author: thangbui
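import numpy as np
from scipy.cluster.vq import kmeans2
from scipy.spatial.distance import cdist


def init_hypers(self, y_train):
    """Initialise the hyperparameters from the training inputs.

    Inducing inputs are placed at k-means centroids of the training
    inputs (or at a random subset of them when N >= 10000), and every
    lengthscale is set from the median pairwise distance between inputs.

    Args:
        y_train (np.ndarray): Training targets; not used by this method.

    Returns:
        dict: 'sf' (log signal-scale parameter, shape (1,)), 'ls'
        (lengthscale parameters in log space, shape (Din,)), 'zu'
        (inducing inputs, shape (M, Din)) and 'sn' (log noise parameter).
    """
    N = self.N
    M = self.M
    Din = self.Din
    x_train = self.x_train

    # Inducing inputs: k-means centroids seeded from random training points,
    # or, for N >= 10000, M training inputs drawn at random.
    if N < 10000:
        centroids, _ = kmeans2(x_train, M, minit='points')
    else:
        randind = np.random.permutation(N)
        centroids = x_train[randind[:M], :]
    zu = centroids

    # Median heuristic: use all inputs, or a random subset of 5000 for large N.
    if N < 10000:
        X1 = np.copy(x_train)
    else:
        randind = np.random.permutation(N)
        X1 = x_train[randind[:5000], :]
    x_dist = cdist(X1, X1, 'euclidean')
    # Upper-triangle indices must match X1's size, which is 5000 after subsampling.
    triu_ind = np.triu_indices(X1.shape[0])
    d2imed = np.median(x_dist[triu_ind])
    # The same value for every input dimension: 2 * log(median distance).
    ls = np.full((Din, ), 2 * np.log(d2imed + 1e-16))
    sf = np.log(np.array([0.5]))

    params = dict()
    params['sf'] = sf
    params['ls'] = ls
    params['zu'] = zu
    params['sn'] = np.log(0.01)
    return params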
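The method above reads its inputs from class attributes (`self.N`, `self.M`, `self.Din`, `self.x_train`) and relies on SciPy. As a hedged sketch of the same initialisation outside the class, the hypothetical `init_hypers_sketch` below takes the data directly and uses only NumPy: a plain Lloyd's k-means (`lloyd_kmeans`, also hypothetical) stands in for `kmeans2(..., minit='points')`, and pairwise distances come from the Gram-matrix identity instead of `cdist`. The 10000-point threshold, the 5000-point subsample and the constants 0.5 and 0.01 are taken from the original; the function names and the `rng`, `max_exact` and `n_sub` parameters are assumptions for illustration.

```python
import numpy as np


def lloyd_kmeans(x, m, n_iter=20, rng=None):
    """Hypothetical NumPy-only stand-in for scipy's kmeans2(minit='points')."""
    rng = np.random.default_rng(rng)
    # Seed the centroids with m distinct training points.
    centroids = x[rng.choice(x.shape[0], size=m, replace=False)].copy()
    for _ in range(n_iter):
        # Squared distance from every point to every centroid, shape (N, m).
        d2 = ((x[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        labels = d2.argmin(axis=1)
        for k in range(m):
            members = x[labels == k]
            if len(members) > 0:  # an empty cluster keeps its old centroid
                centroids[k] = members.mean(axis=0)
    return centroids


def pairwise_dist(x):
    """Euclidean distances via |a|^2 + |b|^2 - 2 a.b (avoids an N x N x D array)."""
    sq = (x ** 2).sum(axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * (x @ x.T)
    return np.sqrt(np.maximum(d2, 0.0))  # clip tiny negatives from round-off


def init_hypers_sketch(x_train, m, max_exact=10000, n_sub=5000, rng=None):
    """Hypothetical standalone version of init_hypers."""
    rng = np.random.default_rng(rng)
    n, d_in = x_train.shape
    if n < max_exact:
        zu = lloyd_kmeans(x_train, m, rng=rng)
        x1 = x_train
    else:
        # Large N: random training inputs as inducing points, random subsample
        # for the distance statistics (two independent draws, as in the original).
        zu = x_train[rng.permutation(n)[:m]]
        x1 = x_train[rng.permutation(n)[:n_sub]]
    # Strict upper triangle (k=1) so the zero self-distances are excluded.
    iu = np.triu_indices(x1.shape[0], k=1)
    med = np.median(pairwise_dist(x1)[iu])
    return {
        "sf": np.log(np.array([0.5])),
        "ls": np.full(d_in, 2 * np.log(med + 1e-16)),
        "zu": zu,
        "sn": np.log(0.01),
    }


if __name__ == "__main__":
    x = np.random.default_rng(0).normal(size=(200, 3))
    params = init_hypers_sketch(x, m=10, rng=1)
    print(params["zu"].shape, params["ls"].shape)  # (10, 3) (3,)
```

One deliberate difference: the sketch takes the median over the strict upper triangle (`k=1`), so the zero self-distances on the diagonal do not pull the median down, whereas the original `np.triu_indices` call uses the default offset and keeps the diagonal.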