test_recurrent.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:seq2seq-lasagne 作者: erfannoury 项目源码 文件源码
def test_lstm_return_final():
    num_batch, seq_len, n_features = 2, 3, 4
    num_units = 2
    in_shp = (num_batch, seq_len, n_features)
    x_in = np.random.random(in_shp).astype('float32')

    l_inp = InputLayer(in_shp)
    lasagne.random.get_rng().seed(1234)
    l_rec_final = LSTMLayer(l_inp, num_units, only_return_final=True)
    lasagne.random.get_rng().seed(1234)
    l_rec_all = LSTMLayer(l_inp, num_units, only_return_final=False)

    output_final = helper.get_output(l_rec_final).eval({l_inp.input_var: x_in})
    output_all = helper.get_output(l_rec_all).eval({l_inp.input_var: x_in})

    assert output_final.shape == (output_all.shape[0], output_all.shape[2])
    assert output_final.shape == lasagne.layers.get_output_shape(l_rec_final)
    assert np.allclose(output_final, output_all[:, -1])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号