test_multi_node_chain_list.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:chainermn 作者: chainer 项目源码 文件源码
def check_cycle_model(gpu):
    communicator, rank_next, rank_prev = create_communicator(gpu)

    n, d = 100, 10

    if communicator.rank == 0:
        X = np.random.randn(n, d).astype(np.float32)
        Y = (np.random.rand(n) * 2).astype(np.int32)
        model = L.Classifier(
            Cycle0(d, communicator, rank_next, rank_prev))

        if gpu:
            model.to_gpu()
            X = chainer.cuda.to_gpu(X)
            Y = chainer.cuda.to_gpu(Y)

        for i in range(n):
            err = model(X[i:i + 1], Y[i:i + 1])
            err.backward()
    else:
        model = Cycle1(
            d, communicator, rank_next, rank_prev)
        if gpu:
            model.to_gpu()

        for i in range(n):
            err = model()
            err.backward()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号