def test_time(self):
    """Check the per-sequence finished flags for a fixed `time` and `lengths`."""
    tf.set_random_seed(23)
    current_time = tf.convert_to_tensor(5, dtype=tf.int32)
    sequence_lengths = tf.constant([4, 5, 6, 7])
    step_output = tf.random_normal([4, 10, 3], dtype=tf.float32)
    helper = layers.TerminationHelper(sequence_lengths)
    finished_op = helper.finished(current_time, step_output)

    # Fixture recap:
    #   time    = 5
    #   lengths = [4, 5, 6, 7]
    #
    # The time is 0-based, so time=5 is the sixth step. The sequences of
    # length 4, 5 and 6 are therefore expected to be flagged as finished,
    # and only the last sequence in the batch (length 7) is still ongoing.
    # NOTE(review): this presumes the helper compares `time + 1` against
    # the lengths -- confirm against TerminationHelper.finished.
    expected_flags = [True, True, True, False]

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        actual_flags = session.run(finished_op)

    self.assertAllEqual(expected_flags, actual_flags)
评论列表
文章目录