def test_iterations(self):
    """Decode a batch with lengths 1, 2 and 3 and check outputs and state.

    The helper is mocked so that an item is reported as finished once
    `time + 1` reaches its length. The decoder is mocked with a
    deterministic step: the output is `input + 1`, the next input is
    `3 * (input + 1)`, the next state is `state + 2`, and the decoder
    itself never reports an item as finished.

    Starting from a zero input, the per-step outputs are 1, 4 and 13.
    The expected output keeps those values up to each item's length and
    is zero afterwards. The expected state is 7 for every item, i.e. the
    initial state of ones incremented by 2 three times.
    """
    seq_lengths = tf.constant([1, 2, 3], dtype=tf.int32)

    def _finished(time, _):
        # Finished as soon as the next step index reaches the length.
        return tf.greater_equal(time + 1, seq_lengths)

    helper = mock.Mock()
    helper.finished.side_effect = _finished

    batch_size = utils.get_dimension(seq_lengths, 0)
    inp_size, state_size, output_size = 2, 5, 2

    def _init_input():
        return tf.zeros([batch_size, inp_size])

    def _init_state():
        return tf.ones([batch_size, state_size])

    def _zero_output():
        return tf.zeros([batch_size, output_size])

    def _step(_time, inp, state):
        # Returns (output, next input, next state, finished flags); the
        # time argument is ignored and the flags are always False.
        return (
            (inp + 1),
            3 * (inp + 1),
            (state + 2),
            tf.tile([False], [batch_size]),
        )

    decoder = mock.Mock()
    decoder.init_input.side_effect = _init_input
    decoder.init_state.side_effect = _init_state
    decoder.zero_output.side_effect = _zero_output
    decoder.step.side_effect = _step

    # One row per batch item, one column pair per decoding step.
    expected_output = np.asarray(
        [
            [[1, 1], [0, 0], [0, 0]],
            [[1, 1], [4, 4], [0, 0]],
            [[1, 1], [4, 4], [13, 13]],
        ],
        dtype=np.float32)  # pylint: disable=E1101,I0011
    # Every entry of the final state is expected to be 7.
    expected_state = np.full(
        (3, state_size), 7,
        dtype=np.float32)  # pylint: disable=E1101,I0011

    dyndec = layers.DynamicDecoder(decoder, helper)
    output_t, state_t = dyndec.decode()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        actual_output, actual_state = sess.run([output_t, state_t])

    self.assertAllEqual(expected_output, actual_output)
    self.assertAllEqual(expected_state, actual_state)
评论列表
文章目录