def testGridRNNEdgeCasesLikeRelu(self):
  """A 1-D GridRNNCell with no input/output/recurrent dims acts like relu.

  The cell has a single dimension, marked non-recurrent, whose
  ``non_recurrent_fn`` is ``nn_ops.relu``. The ``root`` scope's default
  initializer is a constant 0.5. The test checks the static shapes of the
  output and the (empty) state. It then runs the graph on a small fed batch
  and compares the output against hand-computed values.
  """
  with self.test_session() as sess:
    half_init = init_ops.constant_initializer(0.5)
    with variable_scope.variable_scope('root', initializer=half_init):
      inputs = array_ops.zeros([3, 2])
      empty_state = array_ops.zeros([0, 0])
      # With one dimension, marked non-recurrent and using relu as its
      # non-recurrent fn, this cell is meant to be equivalent to relu.
      relu_cell = grid_rnn_cell.GridRNNCell(
          num_units=2,
          num_dims=1,
          input_dims=0,
          output_dims=0,
          non_recurrent_dims=0,
          non_recurrent_fn=nn_ops.relu)
      output, new_state = relu_cell(inputs, empty_state)

      # Static (graph-construction-time) shapes.
      self.assertEqual(output.get_shape(), (3, 2))
      self.assertEqual(new_state.get_shape(), (0, 0))

      sess.run([variables.global_variables_initializer()])
      # `inputs` is not a placeholder, but TF1 lets any feedable tensor be
      # overridden through the feed dict.
      batch = np.array([[1., -1.], [-2., 1.], [2., -1.]])
      output_val, state_val = sess.run([output, new_state], {inputs: batch})

      # Runtime shapes of the evaluated numpy arrays.
      self.assertEqual(output_val.shape, (3, 2))
      self.assertEqual(state_val.shape, (0, 0))
      # NOTE(review): these expectations assume every weight is 0.5 and the
      # bias contributes 0, so each row (a, b) maps to relu(0.5 * (a + b)):
      # only the last row, (2, -1), yields a positive value. Confirm the bias
      # initializer used inside GridRNNCell.
      self.assertAllClose(output_val, [[0, 0], [0, 0], [0.5, 0.5]])
grid_rnn_test.py 文件源码
python
阅读 26
收藏 0
点赞 0
评论 0
评论列表
文章目录