def test_decode_one_step(self):
    """Run `DynamicDecoder.decode()` for one step against mocked parts.

    The decoder is a mock whose `init_input`, `init_state`, `zero_output`
    and `step` each yield one pre-built tensor (or tuple of tensors), and
    the helper is a mock whose `finished` yields an all-`True` flag
    vector. The fetched output is expected to be the step output with an
    extra axis inserted at position 1, and the fetched state is expected
    to be the state returned by `step`.
    """
    base_values = [[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]]

    # Every tensor handed out by the mocks is a distinct multiple of the
    # same constant, so each one can be told apart in the results.
    first_input = tf.constant(base_values)
    first_state = 2 * first_input
    step_input = 3 * first_input
    step_state = 4 * first_input
    step_output = 10 * first_input
    not_finished = tf.constant([False, False, False], dtype=tf.bool)
    blank_output = tf.zeros_like(step_output)

    # Single-element side effects: each mocked method can answer once.
    decoder = mock.Mock()
    decoder.init_input.side_effect = [first_input]
    decoder.init_state.side_effect = [first_state]
    decoder.zero_output.side_effect = [blank_output]
    decoder.step.side_effect = [
        (step_output, step_input, step_state, not_finished),
    ]

    # The helper reports everything as finished, which ends the loop.
    helper = mock.Mock()
    helper.finished.side_effect = [tf.logical_not(not_finished)]

    dynamic_decoder = layers.DynamicDecoder(decoder, helper)
    output_tensor, state_tensor = dynamic_decoder.decode()

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        actual_output, actual_state = session.run(
            [output_tensor, state_tensor])

    # Checks on the fetched values.
    base_array = np.asarray(base_values)
    expected_output = 10 * base_array[:, np.newaxis, :]
    self.assertAllClose(expected_output, actual_output)
    expected_state = 4 * base_array
    self.assertAllClose(expected_state, actual_state)

    # Checks on the mocks. Nothing finer-grained can be asserted here,
    # since the while loop makes all the ops non-fetchable.
    mocked_methods = (
        decoder.init_input,
        decoder.init_state,
        decoder.zero_output,
        decoder.step,
        helper.finished,
    )
    for mocked_method in mocked_methods:
        mocked_method.assert_called_once()
评论列表
文章目录