test_multi_node_chain_list.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:chainermn 作者: chainer 项目源码 文件源码
def check_tuple_data_model(gpu):
    """Exercise a tuple-passing model split across pairs of ranks.

    Ranks are grouped into pairs (0, 1), (2, 3), ..., (2m, 2m+1); only these
    pairs are used. The even rank of a pair wraps ``TupleDataParent`` (which
    is given ``rank_next``) in ``L.Classifier``; the odd rank builds
    ``TupleDataChild`` (which is given ``rank_prev``). For each of the
    samples, the local model is called once and ``backward()`` is run on the
    value it returns.

    Args:
        gpu: When truthy, the model and the input arrays are moved to the GPU
            before training; it is also forwarded to ``create_communicator``.
    """
    communicator, rank_next, rank_prev = create_communicator(gpu)

    n_samples, n_dims = 100, 10
    inputs = np.random.randn(n_samples, n_dims).astype(np.float32)
    # Integer class labels in {0, 1}.
    labels = (np.random.rand(n_samples) * 2).astype(np.int32)

    is_parent = communicator.rank % 2 == 0

    # With an odd number of processes, the last even rank (2m) has no
    # partner to its right, so it takes no part in this test.
    if is_parent and communicator.rank == communicator.size - 1:
        return

    if is_parent:
        parent_link = TupleDataParent(communicator, n_dims, rank_next)
        model = L.Classifier(parent_link)
    else:
        model = TupleDataChild(communicator, n_dims, rank_prev)

    assert model is not None

    if gpu:
        model.to_gpu()
        inputs = chainer.cuda.to_gpu(inputs)
        labels = chainer.cuda.to_gpu(labels)

    for idx in range(n_samples):
        if is_parent:
            # The parent is fed one (x, t) sample per step.
            loss = model(inputs[idx:idx + 1], labels[idx:idx + 1])
        else:
            # The child is called with no arguments.
            loss = model()
        assert loss is not None
        loss.backward()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号