test_gp_signals.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:enterprise 作者: nanograv 项目源码 文件源码
def test_kernel_backend(self):
        """Check basis, kernel and inverse kernel of a per-backend BasisGP.

        A ``BasisGP`` named ``'se'`` is built from the ``se_kernel`` prior
        and the ``create_quant_matrix`` basis with a ``by_backend``
        selection.  The reference quantities are assembled by hand, one
        block per unique entry of ``self.psr.backend_flags``, and compared
        against ``get_basis``, ``get_phi`` and ``get_phiinv`` of the model.
        """
        bin_width = 7*86400

        # signal definition
        by_backend = Selection(selections.by_backend)
        sigma_prior = parameter.Uniform(-10, -5)
        lam_prior = parameter.Uniform(np.log10(86400),
                                      np.log10(1500*86400))
        basis_fn = create_quant_matrix(dt=bin_width)
        prior_fn = se_kernel(log10_sigma=sigma_prior, log10_lam=lam_prior)

        signal_cls = gs.BasisGP(prior_fn, basis_fn, selection=by_backend,
                                name='se')
        model = signal_cls(self.psr)

        # parameter values; index i pairs with the i-th unique backend flag
        sigma_vals = [-7, -6, -6.4, -8.5]
        lam_vals = [8.3, 7.4, 6.8, 5.6]
        params = {
            'B1855+09_se_430_ASP_log10_lam': lam_vals[0],
            'B1855+09_se_430_ASP_log10_sigma': sigma_vals[0],
            'B1855+09_se_430_PUPPI_log10_lam': lam_vals[1],
            'B1855+09_se_430_PUPPI_log10_sigma': sigma_vals[1],
            'B1855+09_se_L-wide_ASP_log10_lam': lam_vals[2],
            'B1855+09_se_L-wide_ASP_log10_sigma': sigma_vals[2],
            'B1855+09_se_L-wide_PUPPI_log10_lam': lam_vals[3],
            'B1855+09_se_L-wide_PUPPI_log10_sigma': sigma_vals[3],
        }

        # per-backend basis blocks and kernel blocks
        flags = self.psr.backend_flags
        masks = [flags == backend for backend in np.unique(flags)]
        basis_blocks, kernel_blocks = [], []
        for idx, mask in enumerate(masks):
            quant, ave_toas = create_quant_matrix(self.psr.toas[mask],
                                                  dt=bin_width)
            basis_blocks.append(quant)
            kernel_blocks.append(se_kernel(ave_toas,
                                           log10_sigma=sigma_vals[idx],
                                           log10_lam=lam_vals[idx]))

        # expected kernel is block diagonal over the backends
        expected_phi = sl.block_diag(*kernel_blocks)
        expected_phiinv = np.linalg.inv(expected_phi)

        # expected basis: each backend's block occupies its own rows (via
        # the mask) and its own consecutive run of columns
        widths = [block.shape[1] for block in basis_blocks]
        expected_basis = np.zeros((len(self.psr.toas), sum(widths)))
        col = 0
        for mask, block, width in zip(masks, basis_blocks, widths):
            expected_basis[mask, col:col + width] = block
            col += width

        # basis test
        msg = 'Kernel basis incorrect for backend signal.'
        assert np.allclose(expected_basis, model.get_basis(params)), msg

        # spectrum test
        msg = 'Kernel incorrect for backend signal.'
        assert np.allclose(model.get_phi(params), expected_phi), msg

        # inverse spectrum test
        msg = 'Kernel inverse incorrect for backend signal.'
        assert np.allclose(model.get_phiinv(params), expected_phiinv), msg
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号