def test_seq2seq_inputs(self):
    """Check ops.seq2seq_inputs on a two-example batch of one-hot sequences.

    The encoder batch holds 2 sequences of 3 steps with 2 features each.
    The decoder batch holds 2 sequences of 2 steps with 3 features each.
    The assertions pin three outputs:
      * the encoder input is the batch with axes 0 and 1 swapped
        (time-major);
      * the decoder input is the time-major decoder sequence with an
        all-zero step prepended;
      * the decoder output is the time-major decoder sequence with an
        all-zero step appended.
    """
    enc_data = np.array([[[1, 0], [0, 1], [1, 0]],
                         [[0, 1], [1, 0], [0, 1]]])
    dec_data = np.array([[[0, 1, 0], [1, 0, 0]],
                         [[1, 0, 0], [0, 1, 0]]])

    # Time-major view of dec_data: step t collects row t of each batch entry.
    first_step = [[0, 1, 0], [1, 0, 0]]
    second_step = [[1, 0, 0], [0, 1, 0]]
    zero_step = [[0, 0, 0], [0, 0, 0]]
    expected_dec_inp = [zero_step, first_step, second_step]
    expected_dec_out = [first_step, second_step, zero_step]

    with self.test_session() as sess:
        enc_ph = tf.placeholder(tf.float32, [2, 3, 2])
        dec_ph = tf.placeholder(tf.float32, [2, 2, 3])
        # NOTE(review): 3 and 2 presumably are the encoder/decoder sequence
        # lengths (they match the placeholder time dims) — confirm in ops.
        enc_in_t, dec_in_t, dec_out_t = ops.seq2seq_inputs(enc_ph, dec_ph, 3, 2)

        # The encoder fetch is fed only the encoder batch, as before.
        enc_inp = sess.run(enc_in_t, feed_dict={enc_ph.name: enc_data})
        both_fed = {enc_ph.name: enc_data, dec_ph.name: dec_data}
        dec_inp = sess.run(dec_in_t, feed_dict=both_fed)
        dec_out = sess.run(dec_out_t, feed_dict=both_fed)

        # batch x len x height -> list (len) of batch x height.
        self.assertAllEqual(enc_inp, np.swapaxes(enc_data, 0, 1))
        self.assertAllEqual(dec_inp, expected_dec_inp)
        self.assertAllEqual(dec_out, expected_dec_out)
评论列表
文章目录