test_transformed_distribution.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def test_log_pdf_on_transformed_distribution(lognormal):
    mu_z = Variable(torch.zeros(1))
    sigma_z = Variable(torch.ones(1))
    dist_params = lognormal.get_dist_params(0)
    mu_lognorm = dist_params['mu']
    sigma_lognorm = dist_params['sigma']
    trans_dist = get_transformed_dist(dist.normal, sigma_lognorm, mu_lognorm)
    test_data = lognormal.get_test_data(0)
    log_px_torch = trans_dist.log_pdf(test_data, mu_z, sigma_z).data[0]
    log_px_np = sp.lognorm.logpdf(
        test_data.data.cpu().numpy(),
        sigma_lognorm.data.cpu().numpy(),
        scale=np.exp(mu_lognorm.data.cpu().numpy()))[0]
    assert_equal(log_px_torch, log_px_np, prec=1e-4)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号