def testGrid2LSTMCellReLUWithRNN(self):
    """Unroll a Grid2LSTMCell (ReLU non-recurrent fn) with static_rnn.

    Builds the unrolled graph, checks the static shape and dtype of every
    per-step output and of the final state, then evaluates all outputs
    plus the state once with an all-ones input and asserts that every
    fetched value is finite.
    """
    batch_size = 3
    input_size = 5
    max_length = 6  # number of time steps the cell is unrolled for
    num_units = 2

    with variable_scope.variable_scope(
        'root', initializer=init_ops.constant_initializer(0.5)):
        cell = grid_rnn_cell.Grid2LSTMCell(
            num_units=num_units, non_recurrent_fn=nn_ops.relu)

        # A single placeholder object is reused for every time step, so
        # feeding it once below supplies the input to all steps.
        step_input = array_ops.placeholder(
            dtypes.float32, shape=(batch_size, input_size))
        inputs = [step_input] * max_length

        outputs, state = core_rnn.static_rnn(
            cell, inputs, dtype=dtypes.float32)

    self.assertEqual(len(outputs), len(inputs))
    self.assertEqual(state.get_shape(), (batch_size, 4))

    for step_output, step_in in zip(outputs, inputs):
        output_shape = step_output.get_shape()
        self.assertEqual(output_shape[0], step_in.get_shape()[0])
        self.assertEqual(output_shape[1], num_units)
        self.assertEqual(step_output.dtype, step_in.dtype)

    with self.test_session() as sess:
        sess.run(variables.global_variables_initializer())
        feed = {step_input: np.ones((batch_size, input_size))}
        fetched = sess.run(outputs + [state], feed_dict=feed)
        for array in fetched:
            self.assertTrue(np.all(np.isfinite(array)))
grid_rnn_test.py 文件源码
python
阅读 30
收藏 0
点赞 0
评论 0
评论列表
文章目录