utils.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:zhusuan 作者: thu-ml 项目源码 文件源码
def test_dtype_1parameter_discrete(test_class, Distribution):
    def _test_sample_dtype(input_, result_dtype, dtype):
        distribution = Distribution(input_, dtype=dtype)
        samples = distribution.sample(2)
        test_class.assertEqual(distribution.dtype, result_dtype)
        test_class.assertEqual(samples.dtype, result_dtype)

    for input_ in [[1.], [[2., 3.], [4., 5.]]]:
        _test_sample_dtype(input_, tf.int32, None)
        _test_sample_dtype(input_, tf.int16, tf.int16)
        _test_sample_dtype(input_, tf.int32, tf.int32)
        _test_sample_dtype(input_, tf.float32, tf.float32)
        _test_sample_dtype(input_, tf.float64, tf.float64)

    def _test_parameter_dtype_raise(param_dtype):
        param = tf.placeholder(param_dtype, [1])
        with test_class.assertRaises(TypeError):
            Distribution(param)

    _test_parameter_dtype_raise(tf.int32)
    _test_parameter_dtype_raise(tf.int64)

    # test dtype for log_prob and prob
    def _test_log_prob_dtype(param_dtype, given_dtype):
        param = tf.placeholder(param_dtype, [1])
        distribution = Distribution(param, dtype=given_dtype)
        test_class.assertEqual(distribution.param_dtype, param_dtype)

        # test for tensor
        given = tf.placeholder(given_dtype, None)
        prob = distribution.prob(given)
        log_prob = distribution.log_prob(given)

        test_class.assertEqual(prob.dtype, param_dtype)
        test_class.assertEqual(log_prob.dtype, param_dtype)

        # test for numpy
        given_np = np.array([1], given_dtype.as_numpy_dtype)
        prob_np = distribution.prob(given_np)
        log_prob_np = distribution.log_prob(given_np)

        test_class.assertEqual(prob_np.dtype, param_dtype)
        test_class.assertEqual(log_prob_np.dtype, param_dtype)

    _test_log_prob_dtype(tf.float16, tf.int32)
    _test_log_prob_dtype(tf.float32, tf.int32)
    _test_log_prob_dtype(tf.float64, tf.int64)
    _test_log_prob_dtype(tf.float32, tf.float32)
    _test_log_prob_dtype(tf.float32, tf.float64)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号