grid_rnn_test.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码
def testGrid3LSTMCellReLUWithRNN(self):
    """Unrolls a ReLU Grid3LSTMCell with static_rnn and checks the result.

    Builds the graph under a 'root' variable scope whose variables are all
    initialized to 0.5. It then checks the static shape and dtype of every
    per-step output and of the final state. Finally it evaluates all outputs
    plus the state once and asserts that every fetched value is finite.
    """
    batch_size = 3
    input_size = 5
    max_length = 6  # number of time steps the RNN is unrolled for
    num_units = 2

    constant_init = init_ops.constant_initializer(0.5)
    with variable_scope.variable_scope('root', initializer=constant_init):
      cell = grid_rnn_cell.Grid3LSTMCell(
          num_units=num_units, non_recurrent_fn=nn_ops.relu)

      # One placeholder reused for every time step: the list below holds the
      # very same tensor object max_length times.
      step_input = array_ops.placeholder(
          dtypes.float32, shape=(batch_size, input_size))
      inputs = [step_input] * max_length

      outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)

    self.assertEqual(len(outputs), len(inputs))
    # NOTE(review): the width 8 presumably comes from the cell's concatenated
    # state (e.g. 2 LSTM dims x (c, h) x num_units) — confirm against
    # Grid3LSTMCell's state layout.
    self.assertEqual(state.get_shape(), (batch_size, 8))

    for step_output, step_in in zip(outputs, inputs):
      out_shape = step_output.get_shape()
      in_shape = step_in.get_shape()
      self.assertEqual(out_shape[0], in_shape[0])
      self.assertEqual(out_shape[1], num_units)
      self.assertEqual(step_output.dtype, step_in.dtype)

    with self.test_session() as sess:
      sess.run(variables.global_variables_initializer())

      # Feeding the shared placeholder once supplies every time step.
      feed = {step_input: np.ones((batch_size, input_size))}
      fetched = sess.run(outputs + [state], feed_dict=feed)
      for value in fetched:
        self.assertTrue(np.all(np.isfinite(value)))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号