def test_normal_shape_scalar_params(self):
normal = Normal(0, 1)
self.assertEqual(normal._batch_shape, torch.Size())
self.assertEqual(normal._event_shape, torch.Size())
self.assertEqual(normal.sample().size(), torch.Size((1,)))
self.assertEqual(normal.sample((3, 2)).size(), torch.Size((3, 2)))
self.assertRaises(ValueError, normal.log_prob, self.scalar_sample)
self.assertEqual(normal.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertEqual(normal.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
评论列表
文章目录