seq2seq_ops_test.py 文件源码

python
阅读 43 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def test_seq2seq_inputs(self):
    inp = np.array([[[1, 0], [0, 1], [1, 0]], [[0, 1], [1, 0], [0, 1]]])
    out = np.array([[[0, 1, 0], [1, 0, 0]], [[1, 0, 0], [0, 1, 0]]])
    with self.test_session() as session:
      x = tf.placeholder(tf.float32, [2, 3, 2])
      y = tf.placeholder(tf.float32, [2, 2, 3])
      in_x, in_y, out_y = ops.seq2seq_inputs(x, y, 3, 2)
      enc_inp = session.run(in_x, feed_dict={x.name: inp})
      dec_inp = session.run(in_y, feed_dict={x.name: inp, y.name: out})
      dec_out = session.run(out_y, feed_dict={x.name: inp, y.name: out})
    # Swaps from batch x len x height to list of len of batch x height.
    self.assertAllEqual(enc_inp, np.swapaxes(inp, 0, 1))
    self.assertAllEqual(dec_inp, [[[0, 0, 0], [0, 0, 0]],
                                  [[0, 1, 0], [1, 0, 0]],
                                  [[1, 0, 0], [0, 1, 0]]])
    self.assertAllEqual(dec_out, [[[0, 1, 0], [1, 0, 0]],
                                  [[1, 0, 0], [0, 1, 0]],
                                  [[0, 0, 0], [0, 0, 0]]])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号