def test_dtype_1parameter_discrete(test_class, Distribution):
    """Check dtype handling of a one-parameter discrete distribution.

    Three properties of ``Distribution`` are exercised:

    * ``distribution.dtype`` and the dtype of ``sample(2)`` match the
      requested ``dtype``; passing ``dtype=None`` is expected to give
      ``tf.int32``.
    * Constructing the distribution from an ``int32``/``int64`` parameter
      placeholder raises ``TypeError``.
    * ``param_dtype`` reports the parameter's dtype, and ``prob`` /
      ``log_prob`` return that dtype for ``given`` supplied both as a TF
      placeholder and as a numpy array.

    Args:
        test_class: Object providing ``assertEqual`` and ``assertRaises``
            (presumably a ``unittest.TestCase`` instance).
        Distribution: The distribution class under test, called as
            ``Distribution(param)`` or ``Distribution(param, dtype=...)``.
    """
    def check_sample(param, expected_dtype, requested_dtype):
        dist = Distribution(param, dtype=requested_dtype)
        drawn = dist.sample(2)
        test_class.assertEqual(dist.dtype, expected_dtype)
        test_class.assertEqual(drawn.dtype, expected_dtype)

    # Pairs of (expected dtype, dtype passed to the constructor).
    sample_cases = (
        (tf.int32, None),
        (tf.int16, tf.int16),
        (tf.int32, tf.int32),
        (tf.float32, tf.float32),
        (tf.float64, tf.float64),
    )
    for param in ([1.], [[2., 3.], [4., 5.]]):
        for expected_dtype, requested_dtype in sample_cases:
            check_sample(param, expected_dtype, requested_dtype)

    def check_int_param_rejected(int_dtype):
        # An integer-typed parameter must be refused at construction time.
        int_param = tf.placeholder(int_dtype, [1])
        with test_class.assertRaises(TypeError):
            Distribution(int_param)

    for int_dtype in (tf.int32, tf.int64):
        check_int_param_rejected(int_dtype)

    # dtype checks for prob and log_prob
    def check_prob_dtype(param_dtype, given_dtype):
        dist = Distribution(tf.placeholder(param_dtype, [1]),
                            dtype=given_dtype)
        test_class.assertEqual(dist.param_dtype, param_dtype)

        def assert_prob_dtypes(given):
            prob = dist.prob(given)
            log_prob = dist.log_prob(given)
            test_class.assertEqual(prob.dtype, param_dtype)
            test_class.assertEqual(log_prob.dtype, param_dtype)

        # `given` as a TF placeholder of unspecified shape ...
        assert_prob_dtypes(tf.placeholder(given_dtype, None))
        # ... and as a plain numpy array of the matching dtype.
        assert_prob_dtypes(np.array([1], given_dtype.as_numpy_dtype))

    # Pairs of (parameter dtype, dtype of `given` / the distribution).
    prob_cases = (
        (tf.float16, tf.int32),
        (tf.float32, tf.int32),
        (tf.float64, tf.int64),
        (tf.float32, tf.float32),
        (tf.float32, tf.float64),
    )
    for param_dtype, given_dtype in prob_cases:
        check_prob_dtype(param_dtype, given_dtype)
评论列表
文章目录