def test_condition(self):
    """Test the DynamicDecoder.cond() method.

    ``cond`` is invoked with ``None`` for every argument except the
    ``finished`` tensor (4th positional argument). Per the cases below, it
    should evaluate to True while at least one entry of ``finished`` is
    False, and to False once every entry is True. Evaluating it must not
    interact with the decoder or the helper.
    """
    helper = mock.Mock()
    decoder = mock.Mock()
    dyndec = layers.DynamicDecoder(decoder, helper)
    # Snapshot whatever interactions the constructor made, so we can check
    # below that cond() itself adds none. `assert_not_called()` alone only
    # checks direct calls on the mock object, not calls to its child
    # attributes/methods, which are recorded in `mock_calls`.
    helper_calls = list(helper.mock_calls)
    decoder_calls = list(decoder.mock_calls)
    # (finished tensor, expected cond() value) pairs.
    cases = [(tf.constant([True], dtype=tf.bool), False),
             (tf.constant([False], dtype=tf.bool), True),
             (tf.constant([True, False], dtype=tf.bool), True),
             (tf.constant([True, True], dtype=tf.bool), False)]
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for finished, expected in cases:
            actual = sess.run(dyndec.cond(None, None, None, finished, None))
            self.assertEqual(expected, actual)
    helper.assert_not_called()
    decoder.assert_not_called()
    self.assertEqual(helper_calls, helper.mock_calls)
    self.assertEqual(decoder_calls, decoder.mock_calls)
评论列表
文章目录