test_layers.py 文件源码

python
阅读 39 收藏 0 点赞 0 评论 0

项目:LiTeFlow 作者: petrux 项目源码 文件源码
def test_time(self):
        """Test that sequences whose last step is at or before `time` are finished.

        With a 0-based `time` of 5 and `lengths = [4, 5, 6, 7]`, only the
        sequence of length 7 is expected to still be running.
        """
        tf.set_random_seed(23)
        time = tf.convert_to_tensor(5, dtype=tf.int32)
        lengths = tf.constant([4, 5, 6, 7])
        # Random payload passed as the `output` argument. Presumably its
        # values do not influence the flags (NOTE(review): confirm against
        # TerminationHelper.finished); shape is (4, 10, 3), with the leading
        # dimension matching the 4 entries of `lengths`.
        output = tf.random_normal([4, 10, 3], dtype=tf.float32)
        finished = layers.TerminationHelper(lengths).finished(time, output)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            act_finished = sess.run(finished)

        # NOTA BENE: we have set
        #   time = 5
        #   lengths = [4, 5, 6, 7]
        #
        # Time is 0-based, so time=5 is the step that reads the 6th element
        # (index 5). The sequences of length 4, 5 and 6 have reached or gone
        # past their last element at this step, so they are expected to be
        # flagged as finished. Only the length-7 sequence is still ongoing.
        # NOTE(review): these expectations imply `finished == (time >= length - 1)`;
        # confirm against the TerminationHelper implementation.
        exp_finished = [True, True, True, False]
        self.assertAllEqual(exp_finished, act_finished)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号