def test_build_hierachical_stacked_lstm_network_with_merge_correct_slice(self):
    """Check that the slice layer picks every second LSTM time step.

    Builds the hierarchical stacked LSTM network, feeds the same input
    through both the LSTM layer and the slice layer, and asserts that the
    slice output equals the LSTM output restricted to time steps 1, 3, 5, ...
    """
    input_shape = 14
    sequence_length = 4
    batch_size = 1
    _, l_lstm, l_slice = build_hierachical_stacked_lstm_network_with_merge(
        input_shape=input_shape,
        sequence_length=sequence_length,
        batch_size=batch_size,
        output_shape=4)

    # Build one compiled function that returns both outputs for the same input.
    states = T.tensor3('states')
    lstm_out = lasagne.layers.get_output(l_lstm, states)
    slice_out = lasagne.layers.get_output(l_slice, states)
    run = theano.function([states], [lstm_out, slice_out])

    # Use a seeded non-constant input so that distinct time steps are likely to
    # produce distinct activations, which makes a wrong slice detectable. Cast
    # to floatX because T.tensor3 is typed with theano.config.floatX, and
    # Theano refuses a float64 input when floatX is float32.
    rng = np.random.RandomState(0)
    sample_states = rng.uniform(
        -1.0, 1.0, size=(batch_size, sequence_length, input_shape)
    ).astype(theano.config.floatX)

    sample_lstm_out, sample_slice_out = run(sample_states)

    # The slice layer should keep odd time steps (1::2) along axis 1.
    self.assertEqual(sample_lstm_out[:, 1::2, :].tolist(),
                     sample_slice_out.tolist())
评论列表
文章目录