test_build_network.py source code

python

Project: hierarchical_rl  Author: wulfebw
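import numpy as np
import theano
import theano.tensor as T
import lasagne
# Assumed import path: the builder is taken to come from the project's
# build_network module (the module under test); adjust if it lives elsewhere.
from build_network import build_hierachical_stacked_lstm_network_with_merge


# Method of a unittest.TestCase subclass in test_build_network.py.
def test_build_hierachical_stacked_lstm_network_with_merge_correct_slice(self):
    input_shape = 14
    sequence_length = 4
    batch_size = 1
    # Only the LSTM layer and the slice layer are needed; the first return value is ignored.
    _, l_lstm, l_slice = build_hierachical_stacked_lstm_network_with_merge(
        input_shape=input_shape,
        sequence_length=sequence_length,
        batch_size=batch_size,
        output_shape=4)

    # Compile one function returning both layers' outputs for a (batch, time, features) input.
    states = T.tensor3('states')
    lstm_out = lasagne.layers.get_output(l_lstm, states)
    slice_out = lasagne.layers.get_output(l_slice, states)
    run = theano.function([states], [lstm_out, slice_out])
    # Match Theano's floatX so the call also works when floatX is float32.
    sample_states = np.zeros((batch_size, sequence_length, input_shape),
                             dtype=theano.config.floatX)
    sample_lstm_out, sample_slice_out = run(sample_states)

    # The slice layer should keep every second LSTM timestep, starting at index 1.
    self.assertEqual(sample_lstm_out[:, 1::2, :].tolist(), sample_slice_out.tolist())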
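The assertion compares the slice layer's output with `lstm_out[:, 1::2, :]`, i.e. timesteps 1 and 3 of a 4-step sequence. The NumPy sketch below illustrates that indexing on a stand-in array, without Theano or Lasagne. It also shows a caveat worth checking: if the network happens to map an all-zero input to an all-zero output (plausible when biases are zero-initialized, but not confirmed by this snippet), then a wrong slice of the same shape would pass the comparison too, and random input gives a stronger check. The size `num_units = 3` and the helpers `every_other_step` and `wrong_slice` are hypothetical, chosen for illustration only.

```python
import numpy as np

# Stand-in for the LSTM output, shape (batch_size, sequence_length, num_units).
# num_units = 3 is a hypothetical size; the real one is set by the network builder.
batch_size, sequence_length, num_units = 1, 4, 3


def every_other_step(x):
    """Hypothetical helper: timesteps 1, 3, 5, ... along axis 1 (the test's 1::2 slice)."""
    return x[:, 1::2, :]


def wrong_slice(x):
    """Hypothetical helper: a deliberately wrong selection (timesteps 0, 2, ...) of the same shape."""
    return x[:, 0::2, :]


zeros = np.zeros((batch_size, sequence_length, num_units))
rng = np.random.default_rng(0)
noise = rng.standard_normal((batch_size, sequence_length, num_units))

# 4 timesteps -> 2 selected timesteps (indices 1 and 3).
print(every_other_step(noise).shape)  # (1, 2, 3)

# With all-zero data, the wrong slice is indistinguishable from the right one.
print(every_other_step(zeros).tolist() == wrong_slice(zeros).tolist())  # True

# With random data, the comparison actually tells the two selections apart.
print(every_other_step(noise).tolist() == wrong_slice(noise).tolist())  # False
```

In the test itself, the same idea would amount to filling `sample_states` with random values instead of `np.zeros`.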