test_fast_likelihood.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:thejoker 作者: adrn 项目源码 文件源码
def test_shit():
    """Check that the Cython batch likelihood agrees with the Python one.

    Simulates a small radial-velocity data set, draws prior samples with
    ``TheJoker.sample_prior``, and evaluates the marginal log-likelihood of
    every sample twice:

    * once with ``batch_marginal_ln_likelihood`` on the whole sample matrix;
    * once per sample with ``marginal_ln_likelihood`` in a Python loop.

    The elapsed time of each approach is printed. Samples for which the
    per-sample Python call raises are recorded as NaN. The two result arrays
    must then agree elementwise, with NaN positions required to match.
    """
    joker_params = JokerParams(P_min=8*u.day, P_max=32768*u.day,
                               jitter=0*u.m/u.s)
    joker = TheJoker(joker_params)

    # Seeded local generator so the simulated data (and therefore any
    # failure) is reproducible. The prior samples themselves come from
    # TheJoker and may still vary between runs.
    rnd = np.random.RandomState(42)

    t = np.sort(rnd.uniform(0, 250, 16) + 56831.324)
    rv = np.cos(t)
    rv_err = rnd.uniform(0.1, 0.2, t.size)

    data = RVData(t=t, rv=rv*u.km/u.s, stddev=rv_err*u.km/u.s)

    samples = joker.sample_prior(size=16384)

    # One row per prior sample and one column per parameter, with columns in
    # the iteration order of ``samples``. np.array drops any units.
    # NOTE(review): made C-contiguous presumably because the batch routine
    # requires a contiguous buffer -- confirm against its signature.
    chunk = np.ascontiguousarray(
        np.vstack([np.array(samples[k]) for k in samples]).T)

    start = time.perf_counter()
    cy_ll = batch_marginal_ln_likelihood(chunk, data, joker_params)
    print("Cython:", time.perf_counter() - start)

    start = time.perf_counter()
    py_ll = np.zeros(len(chunk))
    for i, row in enumerate(chunk):
        try:
            py_ll[i] = marginal_ln_likelihood(row, data, joker_params)
        except Exception:
            # Record a failed evaluation as NaN rather than aborting the
            # whole comparison; the final check requires the batch result
            # to be NaN at the same position.
            py_ll[i] = np.nan
    print("Python:", time.perf_counter() - start)

    # equal_nan=True: without it NaN != NaN, so the sentinel above would make
    # the test fail even when both implementations agree on the failure.
    assert np.allclose(np.array(cy_ll), py_ll, equal_nan=True)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号