def test_init_n(self):
    """Check validation of the ``n_experiments`` argument of ``Binomial``.

    Two paths are exercised:

    * A Python int: the test expects it to be stored as an ``int`` on
      ``dist.n_experiments``, and a value of 0 is expected to raise
      ``ValueError`` ("must be positive") at construction time.
    * An ``int32`` placeholder fed at run time: a non-scalar feed is
      expected to raise ``tf.errors.InvalidArgumentError`` ("should be a
      scalar"), and a feed of 0 is expected to raise the same error type
      ("must be positive"), both when ``dist2.n_experiments`` is evaluated.
    """
    # Static path: n_experiments given as a plain Python int.
    dist = Binomial(tf.ones([2]), 10)
    # assertIsInstance reports the offending value/type on failure, unlike
    # assertTrue(isinstance(...)), which only reports "False is not true".
    self.assertIsInstance(dist.n_experiments, int)
    self.assertEqual(dist.n_experiments, 10)
    # NOTE(review): assertRaisesRegexp is a deprecated alias of
    # assertRaisesRegex (removed in Python 3.12). Kept as-is in case
    # Python 2 support is still required -- confirm before renaming.
    with self.assertRaisesRegexp(ValueError, "must be positive"):
        _ = Binomial(tf.ones([2]), 0)

    # Dynamic path: both inputs are placeholders with unspecified shape,
    # so their values are only known once fed. The assertions below
    # expect the errors to surface when the tensor is evaluated.
    with self.test_session(use_gpu=True):
        logits = tf.placeholder(tf.float32, None)
        n_experiments = tf.placeholder(tf.int32, None)
        dist2 = Binomial(logits, n_experiments)
        # A rank-1 feed ([10]) must be rejected as non-scalar.
        with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
                                     "should be a scalar"):
            dist2.n_experiments.eval(feed_dict={logits: [1.],
                                                n_experiments: [10]})
        # A scalar feed of 0 must be rejected as non-positive.
        with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
                                     "must be positive"):
            dist2.n_experiments.eval(feed_dict={logits: [1.],
                                                n_experiments: 0})
评论列表
文章目录