def test_build_hierachical_stacked_lstm_network_with_merge_correct_slice_shared_var(self):
    """Check that the network's slice layer selects the expected LSTM timesteps.

    Builds the hierarchical stacked LSTM network and gets symbolic outputs for
    both the LSTM layer and the slice layer. It evaluates them through one
    compiled Theano function whose input is supplied from a shared variable
    via ``givens``. It then asserts that the slice output equals
    ``lstm_out[:, 1::2, :]``.
    """
    input_shape = 14
    # Must be >= 2: with a single timestep ``[:, 1::2, :]`` is empty and the
    # comparison at the end would pass vacuously.
    sequence_length = 4
    batch_size = 1
    _, l_lstm, l_slice = build_hierachical_stacked_lstm_network_with_merge(
        input_shape=input_shape,
        sequence_length=sequence_length,
        batch_size=batch_size,
        output_shape=4)

    states = T.tensor3('states')
    lstm_out = lasagne.layers.get_output(l_lstm, states)
    slice_out = lasagne.layers.get_output(l_slice, states)

    # ``givens`` requires the replacement to have exactly the symbolic input's
    # type, so allocate the shared variable with ``floatX`` rather than NumPy's
    # default float64 (which fails when floatX is float32).
    shape = (batch_size, sequence_length, input_shape)
    states_shared = theano.shared(np.zeros(shape, dtype=theano.config.floatX))
    run = theano.function([], [lstm_out, slice_out], givens={states: states_shared})

    # Use non-zero input: an all-zero input may yield identical outputs at
    # every timestep, which would hide a slice that picks the wrong steps.
    # Seeded so the test is reproducible.
    rng = np.random.RandomState(0)
    sample_states = rng.uniform(-1.0, 1.0, size=shape).astype(theano.config.floatX)
    states_shared.set_value(sample_states)
    sample_lstm_out, sample_slice_out = run()

    expected = sample_lstm_out[:, 1::2, :]
    # Guard against a vacuous comparison of empty arrays.
    self.assertGreater(expected.size, 0)
    self.assertEqual(expected.tolist(), sample_slice_out.tolist())
评论列表
文章目录