def test_next_inp_without_decoder_inputs(self):  # pylint: disable=C0103
    """Check `.next_inp` when the decoder is built without decoder inputs.

    A 3x3 float output is passed to a `PointingSoftmaxDecoder` created
    with `input_size=4`; the test expects each row of the result to be
    the corresponding output row followed by a single trailing zero.
    """
    input_size = 4
    raw_output = [[1, 1, 1], [2, 2, 2], [3, 3, 3]]

    # Tensors fed to (or reachable from) the decoder.
    attention_states = tf.random_normal([3, 10, 4])
    output_t = tf.constant(raw_output, dtype=tf.float32)
    # The time step is irrelevant here, so any value will do.
    step_t = tf.constant(random.randint(0, 100), dtype=tf.int32)

    # Collaborators are mocked; only the attention states are a real tensor.
    cell_mock = mock.Mock()
    softmax_mock = mock.Mock()
    softmax_mock.attention.states = attention_states
    pointing_mock = mock.Mock()

    decoder = layers.PointingSoftmaxDecoder(
        cell=cell_mock,
        location_softmax=softmax_mock,
        pointing_output=pointing_mock,
        input_size=input_size)
    result_t = decoder.next_inp(step_t, output_t)

    # pylint: disable=E1101
    expected = np.asarray(
        [[1, 1, 1, 0],
         [2, 2, 2, 0],
         [3, 3, 3, 0]],
        dtype=np.float32)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        actual = sess.run(result_t)

    self.assertAllEqual(expected, actual)
评论列表
文章目录