test_multi_node_chain_list.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:chainermn 作者: chainer 项目源码 文件源码
def check_branching_model(gpu, communicator, rank_next, rank_prev,
                          parent_model):
    n, d = 100, 10
    X = np.random.randn(n, d).astype(np.float32)
    Y = (np.random.rand(n) * 2).astype(np.int32)

    if communicator.rank == 0:
        rank_children = [rank for rank in range(1, communicator.size)]
        model = L.Classifier(parent_model(
            d, communicator, rank_children))
        if gpu:
            model.to_gpu()
            X = chainer.cuda.to_gpu(X)
            Y = chainer.cuda.to_gpu(Y)

        for i in range(n):
            err = model(X[i:i + 1], Y[i:i + 1])
            err.backward()
    else:
        model = BranchChild(d, communicator, 0)
        if gpu:
            model.to_gpu()

        for i in range(n):
            err = model()
            err.backward()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号