def test_log_pdf_on_transformed_distribution(lognormal):
    """Compare a transformed normal's log-density with SciPy's lognormal.

    The fixture's ``mu`` and ``sigma`` are handed to ``get_transformed_dist``
    together with ``dist.normal``. The resulting distribution's ``log_pdf`` is
    evaluated at the fixture's test data with a zero-mean, unit-scale base.
    That value must match ``scipy.stats.lognorm.logpdf`` to within 1e-4.

    NOTE(review): this presumes ``get_transformed_dist`` maps a standard
    normal onto LogNormal(mu, sigma) (e.g. affine then exp) — confirm against
    its definition.
    """
    # Base-distribution parameters: a standard normal.
    base_loc = Variable(torch.zeros(1))
    base_scale = Variable(torch.ones(1))

    params = lognormal.get_dist_params(0)
    mu, sigma = params['mu'], params['sigma']
    transformed = get_transformed_dist(dist.normal, sigma, mu)
    observed = lognormal.get_test_data(0)

    # First element of the tensor returned by log_pdf.
    actual = transformed.log_pdf(observed, base_loc, base_scale).data[0]

    def as_array(var):
        # Detach from autograd and move to host memory for SciPy.
        return var.data.cpu().numpy()

    # SciPy parameterises the lognormal with shape s = sigma and
    # scale = exp(mu) (the exponential of the underlying normal's mean).
    reference = sp.lognorm.logpdf(as_array(observed),
                                  as_array(sigma),
                                  scale=np.exp(as_array(mu)))
    expected = reference[0]

    assert_equal(actual, expected, prec=1e-4)
评论列表
文章目录