grid_rnn_test.py source code

python

Project: DeepLearning_VirtualReality_BigData_Project  Author: rashmitripathi
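
# Assumes an early TensorFlow 1.x build with tf.contrib.grid_rnn, where the
# cell's state is a single concatenated tensor (as m below is).
import numpy as np

from tensorflow.contrib.grid_rnn.python.ops import grid_rnn_cell
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class GridRNNCellTest(test.TestCase):

  def testGridRNNEdgeCasesNoOutput(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          'root', initializer=init_ops.constant_initializer(0.5)):
        # Batch of 1: a 2-wide input and a 4-wide concatenated state.
        x = array_ops.zeros([1, 2])
        m = array_ops.zeros([1, 4])

        # This cell produces no output: dimension 0 takes the input and is
        # non-recurrent (via non_recurrent_fn); output_dims=None exposes no output.
        cell = grid_rnn_cell.GridRNNCell(
            num_units=2,
            num_dims=2,
            input_dims=0,
            output_dims=None,
            non_recurrent_dims=0,
            non_recurrent_fn=nn_ops.relu)
        g, s = cell(x, m)
        # Static shapes: an empty (0, 0) output; the state keeps m's shape.
        self.assertEqual(g.get_shape(), (0, 0))
        self.assertEqual(s.get_shape(), (1, 4))

        # The same shapes must hold at run time, with concrete values fed in.
        sess.run([variables.global_variables_initializer()])
        res = sess.run(
            [g, s],
            {x: np.array([[1., 1.]]),
             m: np.array([[0.1, 0.1, 0.1, 0.1]])})
        self.assertEqual(res[0].shape, (0, 0))
        self.assertEqual(res[1].shape, (1, 4))


if __name__ == '__main__':
  test.main()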
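The test pins down a shape contract rather than numeric values: with `output_dims=None` the cell's output is an empty `(0, 0)` tensor, while the state keeps the `(1, 4)` shape of `m`, both when the graph is built and after `sess.run`. The sketch below restates that contract in plain Python with `unittest`, so it runs without TensorFlow 1.x. `ToyNoOutputCell`, `shape_of`, and the toy state/output logic are hypothetical illustrations only; they do not reproduce what `GridRNNCell` actually computes.

```python
import unittest
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def shape_of(m: Matrix) -> Tuple[int, int]:
    """Return (rows, cols) of a list-of-rows matrix; an empty one is (0, 0)."""
    return (len(m), len(m[0]) if m else 0)


class ToyNoOutputCell:
    """Hypothetical stand-in, NOT GridRNNCell: it only mimics the shapes the
    TensorFlow test asserts, so the contract can be checked without TF."""

    def __init__(self, output_dims: Optional[int] = None):
        self.output_dims = output_dims

    def __call__(self, x: Matrix, state: Matrix) -> Tuple[Matrix, Matrix]:
        # Toy state update: return a copy, so the state keeps its shape.
        new_state = [list(row) for row in state]
        if self.output_dims is None:
            # No output dimension: the output is an empty (0, 0) matrix.
            return [], new_state
        # Toy output (hypothetical): echo the input rows.
        return [list(row) for row in x], new_state


class ToyNoOutputCellTest(unittest.TestCase):

    def test_no_output_edge_case(self):
        cell = ToyNoOutputCell(output_dims=None)
        g, s = cell([[1.0, 1.0]], [[0.1, 0.1, 0.1, 0.1]])
        # The same two shape checks as the TensorFlow test.
        self.assertEqual(shape_of(g), (0, 0))
        self.assertEqual(shape_of(s), (1, 4))
```

Saved as a module, the test case runs under `python -m unittest <module>`. The check uses `output_dims is None` rather than a truthiness test, because `0` is a valid dimension index (the original test passes `input_dims=0`).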