base_models.py source file

python

Project: geepee  Author: thangbui
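import numpy as np


def get_hypers(self, key_suffix=''):
    """Collect the layer's hyperparameters into a flat dictionary.

    Args:
        key_suffix (str, optional): String appended to every key, e.g. to
            tell apart the parameters of different layers.

    Returns:
        dict: Kernel hyperparameters ('ls', 'sf'), inducing inputs ('zu'),
        and the parameters 'eta1_R' and 'eta2', each key suffixed with
        key_suffix. 'eta1_R' holds, for every output dimension d, the
        upper-triangular entries of theta_1_R[d] with the diagonal on a
        log scale.
    """
    params = {}
    M = self.M        # number of inducing points
    Dout = self.Dout  # number of output dimensions
    params['ls' + key_suffix] = self.ls
    params['sf' + key_suffix] = self.sf
    triu_ind = np.triu_indices(M)
    diag_ind = np.diag_indices(M)
    # Integer division: np.zeros needs an integer shape, and
    # M * (M + 1) / 2 is a float in Python 3.
    params_eta1_R = np.zeros((Dout, M * (M + 1) // 2))

    for d in range(Dout):
        # Work on a copy so the stored theta_1_R is not modified in place
        Rd = np.copy(self.theta_1_R[d, :, :])
        # Store the diagonal on a log scale
        Rd[diag_ind] = np.log(Rd[diag_ind])
        # Keep only the upper-triangular entries (row-major order)
        params_eta1_R[d, :] = Rd[triu_ind]

    params['zu' + key_suffix] = self.zu
    params['eta1_R' + key_suffix] = params_eta1_R
    params['eta2' + key_suffix] = self.theta_2
    return params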
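The loop in `get_hypers` packs each M x M matrix R_d into a vector of length M(M+1)/2: the diagonal is replaced by its log and only the upper-triangular entries are kept, in row-major order. A minimal sketch of that packing and a matching inverse follows. `pack_triu_log_diag` and `unpack_triu_log_diag` are hypothetical helpers written for illustration, not geepee functions, and the inverse assumes R_d is upper triangular with a strictly positive diagonal (otherwise the log yields NaN or -inf). The original code checks neither condition; it simply drops any lower-triangular entries.

```python
import numpy as np


def pack_triu_log_diag(R):
    """Pack an upper-triangular M x M matrix into a vector of length
    M*(M+1)//2, storing the diagonal on a log scale (same steps as the
    loop body of get_hypers). Hypothetical helper."""
    M = R.shape[0]
    Rc = np.copy(R)                      # do not modify the caller's array
    diag_ind = np.diag_indices(M)
    Rc[diag_ind] = np.log(Rc[diag_ind])  # requires a positive diagonal
    return Rc[np.triu_indices(M)]        # row-major upper triangle


def unpack_triu_log_diag(vec, M):
    """Inverse of pack_triu_log_diag, assuming the original matrix was
    upper triangular. Hypothetical helper, not part of geepee."""
    if vec.shape != (M * (M + 1) // 2,):
        raise ValueError("expected a vector of length M*(M+1)//2")
    R = np.zeros((M, M))
    R[np.triu_indices(M)] = vec          # refill the upper triangle
    diag_ind = np.diag_indices(M)
    R[diag_ind] = np.exp(R[diag_ind])    # undo the log on the diagonal
    return R


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    M = 4
    # Random upper-triangular matrix with a strictly positive diagonal
    R = np.triu(rng.normal(size=(M, M)))
    R[np.diag_indices(M)] = np.abs(R[np.diag_indices(M)]) + 0.1
    packed = pack_triu_log_diag(R)
    print(packed.shape)                                     # (10,)
    print(np.allclose(unpack_triu_log_diag(packed, M), R))  # True
```

Storing only the upper triangle is why the second dimension of `params_eta1_R` is M(M+1)/2 rather than M*M; for M = 4 that is 10 entries instead of 16.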