def test_batch_log_pdf_mask(dist):
    """Check how ``log_pdf_mask`` affects ``batch_log_pdf`` for each test datum.

    Only the Normal, Bernoulli and Categorical test distributions are
    exercised; any other distribution is skipped. For every test datum the
    unmasked ``batch_log_pdf`` is compared against three masked variants:

    * an all-ones mask of shape ``batch_shape(**params) + (1,)`` must leave
      the result unchanged;
    * a single-element zero mask must give zeros of shape
      ``batch_shape(x, **params) + (1,)``;
    * a single-element 0.5 mask must give half of the unmasked result.
    """
    supported = ('Normal', 'Bernoulli', 'Categorical')
    if dist.get_test_distribution_name() not in supported:
        pytest.skip('Batch pdf masking not supported for the distribution.')

    pyro_dist = dist.pyro_dist
    for i in range(dist.get_num_test_data()):
        params = dist.get_dist_params(i)
        data = dist.get_test_data(i)
        # NOTE(review): presumably xfail_if_not_implemented marks the case as
        # xfail when the distribution lacks an implementation — confirm.
        with xfail_if_not_implemented():
            param_shape = pyro_dist.batch_shape(**params) + (1,)
            data_shape = pyro_dist.batch_shape(data, **params) + (1,)

            # The single-element masks and the ones mask are expected to be
            # broadcast to the data dims by batch_log_pdf.
            zero_mask = ng_zeros(1)
            one_mask = ng_ones(param_shape)
            half_mask = ng_ones(1) * 0.5

            def masked_log_pdf(mask):
                # Same call as the unmasked one, plus the mask keyword.
                return pyro_dist.batch_log_pdf(data, log_pdf_mask=mask, **params)

            unmasked = pyro_dist.batch_log_pdf(data, **params)
            zeroed = masked_log_pdf(zero_mask)
            kept = masked_log_pdf(one_mask)
            halved = masked_log_pdf(half_mask)

            assert_equal(kept, unmasked)
            assert_equal(zeroed, ng_zeros(data_shape))
            assert_equal(halved, 0.5 * unmasked)
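# 评论列表 ("comment list") — scraped web-page residue, not part of the test
# 文章目录 ("table of contents") — scraped web-page residue, not part of the test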