test_recurrent.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:seq2seq-lasagne 作者: erfannoury 项目源码 文件源码
def test_lnlstm_unroll_scan_bck():
    """Check that a backwards LNLSTMLayer gives the same output whether it
    is built with ``unroll_scan=False`` or ``unroll_scan=True``.

    Both layers are constructed right after seeding Lasagne's RNG with the
    same value, fed the same random float32 input, and their evaluated
    outputs are compared with ``np.testing.assert_almost_equal``.
    """
    batch_size, n_steps, n_feats = 2, 3, 4
    hidden_units = 2
    input_shape = (batch_size, n_steps, n_feats)

    sym_input = T.tensor3()
    input_layer = InputLayer(input_shape)

    # The input sample is drawn before any seeding takes place, exactly as
    # the layers below are built only after the seed has been set.
    sample = np.random.random(input_shape).astype('float32')

    # Build the scan-based layer first, then the unrolled one; the RNG is
    # reset to the same seed immediately before each construction.
    layers = []
    for unroll in (False, True):
        lasagne.random.get_rng().seed(1234)
        layers.append(
            LNLSTMLayer(
                input_layer,
                num_units=hidden_units,
                backwards=True,
                unroll_scan=unroll,
            )
        )
    scan_layer, unrolled_layer = layers

    # Symbolic outputs for both variants, then numeric evaluation on the
    # shared input sample.
    scan_expr = helper.get_output(scan_layer, sym_input)
    unrolled_expr = helper.get_output(unrolled_layer, sym_input)
    feed = {sym_input: sample}
    scan_result = scan_expr.eval(feed)
    unrolled_result = unrolled_expr.eval(feed)

    np.testing.assert_almost_equal(scan_result, unrolled_result)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号