def testGrid2LSTMCellTied(self):
    """Check a tied, peephole-enabled Grid2LSTMCell against reference values.

    Builds a ``Grid2LSTMCell`` with 2 units, ``tied=True`` and
    ``use_peepholes=True`` inside a variable scope whose default
    initializer is the constant 0.5, then verifies:

    * the cell reports a state size of 8;
    * the static shapes of the output ``(1, 2)`` and new state ``(1, 8)``;
    * the evaluated output and new state match hard-coded reference
      values for an all-ones input and the state vector 0.1 .. 0.8.
    """
    with self.test_session() as sess:
        with variable_scope.variable_scope(
                'root', initializer=init_ops.constant_initializer(0.5)):
            # Graph-time stand-ins: the zeros are never used at run time,
            # because both tensors are overridden through the feed dict.
            inputs = array_ops.zeros([1, 3])
            prev_state = array_ops.zeros([1, 8])

            cell = grid_rnn_cell.Grid2LSTMCell(
                2, tied=True, use_peepholes=True)
            self.assertEqual(cell.state_size, 8)

            output, new_state = cell(inputs, prev_state)
            # Static (graph-construction-time) shape checks.
            self.assertEqual(output.get_shape(), (1, 2))
            self.assertEqual(new_state.get_shape(), (1, 8))

            sess.run(variables.global_variables_initializer())

            feed = {
                inputs: np.array([[1., 1., 1.]]),
                prev_state: np.array(
                    [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]]),
            }
            output_val, state_val = sess.run([output, new_state], feed)

            # Runtime shapes must agree with the static shapes above.
            self.assertEqual(output_val.shape, (1, 2))
            self.assertEqual(state_val.shape, (1, 8))

            # Hard-coded reference numbers. The expected output equals
            # entries 2 and 3 of the expected state.
            expected_output = [[0.95686918, 0.95686918]]
            expected_state = [[2.41515064, 2.41515064, 0.95686918,
                               0.95686918, 1.38917875, 1.49043763,
                               0.83884692, 0.86036491]]
            self.assertAllClose(output_val, expected_output)
            self.assertAllClose(state_val, expected_state)
grid_rnn_test.py 文件源码
python
阅读 31
收藏 0
点赞 0
评论 0
评论列表
文章目录