def test_output(self):
    """Test the DynamicDecoder.output() method.

    Feeds a two-row batch where only the first row is flagged finished and
    checks three things:

    * the finished row comes back as the (mocked) decoder zero output and the
      unfinished row passes through unchanged;
    * ``decoder.zero_output`` is requested exactly once;
    * ``helper.finished`` is never called.
    """
    # Row 0 is finished -> zeros; row 1 is still running -> raw 23s.
    expected = np.array([[0, 0, 0], [23, 23, 23]], dtype=np.float32)  # pylint: disable=I0011,E1101

    zeros = tf.constant([[0, 0, 0], [0, 0, 0]], dtype=tf.float32)
    cell_output = tf.constant([[23, 23, 23], [23, 23, 23]], dtype=tf.float32)
    done = tf.constant([True, False], dtype=tf.bool)

    mock_helper = mock.Mock()
    mock_decoder = mock.Mock()
    # One-element iterable side_effect: a second zero_output() call would
    # raise StopIteration, so accidental repeated calls surface loudly.
    mock_decoder.zero_output.side_effect = [zeros]

    output_op = layers.DynamicDecoder(mock_decoder, mock_helper).output(
        cell_output, done)

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        actual = session.run(output_op)

    mock_helper.finished.assert_not_called()
    mock_decoder.zero_output.assert_called_once()
    self.assertAllEqual(expected, actual)
评论列表
文章目录