rnn_cell_test.py 文件源码

python
阅读 37 收藏 0 点赞 0 评论 0

项目:seq2seq 作者: google 项目源码 文件源码
def test_without_residuals(self):
  """ExtendedMultiRNNCell without residuals must match MultiRNNCell.

  Builds two stacks of two GRU cells (2 units each) under one variable scope
  with a shared constant initializer. Both stacks get the same random input
  and the same tuple state. The test then checks that the first result entry
  and both per-layer state entries agree between the two stacks.
  """
  num_layers = 2
  num_units = 2

  # Random draws happen in the same order as before: input, layer 0, layer 1.
  rand_input = tf.constant(np.random.randn(1, 2))
  layer0_state = tf.constant(np.random.randn(1, 2))
  layer1_state = tf.constant(np.random.randn(1, 2))
  initial_state = (layer0_state, layer1_state)

  # A shared constant initializer gives both stacks identical weights, so any
  # mismatch is attributable to the cell implementation itself.
  with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
    reference_cell = tf.contrib.rnn.MultiRNNCell(
        [tf.contrib.rnn.GRUCell(num_units) for _ in range(num_layers)],
        state_is_tuple=True)
    reference_result = reference_cell(
        rand_input, initial_state, scope="standard")

    extended_cell = rnn_cell.ExtendedMultiRNNCell(
        [tf.contrib.rnn.GRUCell(num_units) for _ in range(num_layers)])
    extended_result = extended_cell(
        rand_input, initial_state, scope="test")

  with self.test_session() as session:
    session.run([tf.global_variables_initializer()])
    reference_value, extended_value = session.run(
        [reference_result, extended_result])

  # Entry [0] is presumably the cell output and entry [1] the per-layer state
  # tuple (standard RNNCell (output, state) return) -- compare each piece.
  self.assertAllClose(reference_value[0], extended_value[0])
  for layer in range(num_layers):
    self.assertAllClose(reference_value[1][layer], extended_value[1][layer])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号