test_build_network.py source code

python

Project: hierarchical_rl    Author: wulfebw
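import unittest

import numpy as np
import theano
import theano.tensor as T
import lasagne

# build_hierachical_stacked_lstm_network_with_merge comes from the hierarchical_rl
# project; its import path is not shown in this snippet.


class TestBuildNetwork(unittest.TestCase):  # class name assumed, not shown in the source

    def test_build_hierachical_stacked_lstm_network_with_merge_correct_slice_longer_len_seq(self):
        input_shape = 14
        sequence_length = 7
        batch_size = 1
        # The builder's result is unpacked into the output, LSTM, and slice layers.
        l_out, l_lstm, l_slice = build_hierachical_stacked_lstm_network_with_merge(
            input_shape=input_shape,
            sequence_length=sequence_length,
            batch_size=batch_size,
            output_shape=4,
            start=0,
            downsample=3)

        # Compile one function that returns the outputs of all three layers.
        states = T.tensor3('states')
        l_out_out = lasagne.layers.get_output(l_out, states)
        lstm_out = lasagne.layers.get_output(l_lstm, states)
        slice_out = lasagne.layers.get_output(l_slice, states)
        run = theano.function([states], [l_out_out, lstm_out, slice_out])

        # Use floatX so Theano does not reject a float64 input when floatX is float32.
        sample_states = np.zeros((batch_size, sequence_length, input_shape),
                                 dtype=theano.config.floatX)
        sample_out, sample_lstm_out, sample_slice_out = run(sample_states)

        # The slice layer should keep every third LSTM time step from 0 (steps 0, 3, 6).
        self.assertEqual(sample_lstm_out[:, 0::3, :].tolist(), sample_slice_out.tolist())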
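The assertion compares the LSTM output at every third time step with the slice layer's output. The NumPy-only sketch below shows which steps that `[:, start::downsample, :]` indexing keeps for `sequence_length = 7`, assuming a `(batch, time, features)` layout as in the test. It does not reproduce the project's network; the helper `expected_slice_indices` is hypothetical. Because 7 is not a multiple of 3, the slice keeps ceil(7 / 3) = 3 steps.

```python
import numpy as np


def expected_slice_indices(sequence_length, start, downsample):
    """Hypothetical helper: time-step indices kept by a start::downsample slice."""
    return list(range(sequence_length))[start::downsample]


batch_size, sequence_length, input_shape = 1, 7, 14
start, downsample = 0, 3

# Stand-in for the LSTM output with shape (batch, time, features),
# filled with the time index so the kept steps are easy to read off.
lstm_out = np.broadcast_to(
    np.arange(sequence_length)[None, :, None],
    (batch_size, sequence_length, input_shape),
)

# Same slicing the test applies before comparing against the slice layer.
sliced = lstm_out[:, start::downsample, :]

print(expected_slice_indices(sequence_length, start, downsample))  # [0, 3, 6]
print(sliced.shape)  # (1, 3, 14)
print(sliced[0, :, 0].tolist())  # [0, 3, 6]
```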