test_layers.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:LiTeFlow 作者: petrux 项目源码 文件源码
def test_decode_one_step(self):
        """Test DynamicDecoder.decode() when decoding stops after one step.

        The mocked helper reports every sequence as finished right after the
        first decoder step, so decode() is expected to emit a single output
        step and to return the state produced by that step.
        """

        init_value = [[.1, .1], [.2, .2], [.3, .3]]
        init_input = tf.constant(init_value)
        init_state = 2 * init_input
        next_input = 3 * init_input
        next_state = 4 * init_input
        output = 10 * init_input
        finished = tf.constant([False, False, False], dtype=tf.bool)
        zero_output = tf.zeros_like(output)

        # Every side_effect is a one-element list: a second call to any of
        # these mocks would exhaust it and raise StopIteration.
        decoder = mock.Mock()
        decoder.init_input.side_effect = [init_input]
        decoder.init_state.side_effect = [init_state]
        decoder.zero_output.side_effect = [zero_output]
        decoder.step.side_effect = [(output, next_input, next_state, finished)]

        # The decoder step reports "not finished" for every sequence, but the
        # helper returns the negation (all True), which should make the
        # decoding loop exit after the first step.
        helper = mock.Mock()
        helper.finished.side_effect = [tf.logical_not(finished)]

        dyndec = layers.DynamicDecoder(decoder, helper)
        output_t, state_t = dyndec.decode()

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            output_act, state_act = sess.run([output_t, state_t])

        # Assertions on output. The expected tensor has shape [3, 1, 2]: the
        # single step's output with a length-1 axis inserted at position 1
        # (presumably a batch-major [batch, time, depth] layout).
        output_exp = 10 * np.transpose(np.asarray([init_value]), (1, 0, 2))
        self.assertAllClose(output_exp, output_act)
        # The returned state is the one produced by the only step.
        state_exp = 4 * np.asarray(init_value)
        self.assertAllClose(state_exp, state_act)

        # Mock assertions. We cannot assert more than this since the while
        # loop makes all the ops non-fetchable. `call_count` is compared
        # directly because `assert_called_once()` only exists in newer mock
        # releases. In older ones it is either missing or auto-created as a
        # child Mock, in which case calling it silently passes.
        self.assertEqual(1, decoder.init_input.call_count)
        self.assertEqual(1, decoder.init_state.call_count)
        self.assertEqual(1, decoder.zero_output.call_count)
        self.assertEqual(1, decoder.step.call_count)
        self.assertEqual(1, helper.finished.call_count)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号