test_distributions.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def test_batch_log_pdf_mask(dist):
    """Check how ``log_pdf_mask`` affects the result of ``batch_log_pdf``.

    For every test case exposed by the ``dist`` fixture, the test asserts that:

    * a mask of ones leaves ``batch_log_pdf`` unchanged,
    * a mask of zeros yields an all-zero result,
    * a mask of 0.5 halves the unmasked result.

    The test is skipped for any distribution other than Normal, Bernoulli
    and Categorical.
    """
    maskable_names = ('Normal', 'Bernoulli', 'Categorical')
    if dist.get_test_distribution_name() not in maskable_names:
        pytest.skip('Batch pdf masking not supported for the distribution.')

    pyro_dist = dist.pyro_dist
    for case in range(dist.get_num_test_data()):
        params = dist.get_dist_params(case)
        data = dist.get_test_data(case)

        def masked_log_pdf(mask):
            # Evaluate batch_log_pdf on this case's data with the given mask.
            return pyro_dist.batch_log_pdf(data, log_pdf_mask=mask, **params)

        with xfail_if_not_implemented():
            # Batch shape from the parameters alone, and from data + parameters;
            # each gets a trailing dimension of size 1 appended.
            param_shape = pyro_dist.batch_shape(**params) + (1,)
            data_shape = pyro_dist.batch_shape(data, **params) + (1,)

            # The single-element masks and the parameter-shaped mask are
            # expected to be broadcast against the data dimensions.
            mask_zero = ng_zeros(1)
            mask_one = ng_ones(param_shape)
            mask_half = ng_ones(1) * 0.5

            unmasked = pyro_dist.batch_log_pdf(data, **params)
            with_zero = masked_log_pdf(mask_zero)
            with_one = masked_log_pdf(mask_one)
            with_half = masked_log_pdf(mask_half)

            assert_equal(with_one, unmasked)
            assert_equal(with_zero, ng_zeros(data_shape))
            assert_equal(with_half, 0.5 * unmasked)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号