debug.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:chainer-qrnn 作者: musyoku 项目源码 文件源码
def test_rnn():
    """Check that incremental decoding reproduces the full-sequence output.

    A random integer batch is fed to the model once as a whole sequence and
    then again one time step at a time through ``forward_one_step``. At every
    step the incremental output must equal the matching slice of the
    full-sequence output exactly (zero squared error).

    Raises:
        AssertionError: if the outputs differ at any time step.
    """
    np.random.seed(0)
    num_layers = 50
    seq_length = num_layers * 2
    batchsize = 2
    vocab_size = 4
    data = np.random.randint(0, vocab_size, size=(batchsize, seq_length), dtype=np.int32)
    # Only the source half of the pair is consumed by this test.
    source, _ = make_source_target_pair(data)
    model = RNNModel(vocab_size, ndim_embedding=100, num_layers=num_layers, ndim_h=3, kernel_size=3, pooling="fo", zoneout=False, wgain=1, densely_connected=True)

    with chainer.using_config("train", False):
        # Reference pass over the whole sequence at once.
        np.random.seed(0)
        model.reset_state()
        Y = model(source).data
        # Loop-invariant view of the reference output, indexed as
        # (batch, step, vocab). Assumes model(source) flattens batch and time
        # into the leading axis -- TODO confirm against RNNModel.
        Y_by_step = np.reshape(Y, (batchsize, -1, vocab_size))

        # Incremental pass: feed a growing prefix and compare step by step.
        model.reset_state()
        np.random.seed(0)
        for t in range(source.shape[1]):
            y = model.forward_one_step(source[:, :t+1]).data
            # Same values the original obtained via two swapaxes/reshape round trips.
            expected = Y_by_step[:, t]
            assert np.sum((y - expected) ** 2) == 0, "output mismatch at t = {}".format(t)
            print("t = {} OK".format(t))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号