test_multi_node_chain_list.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:chainermn 作者: chainer 项目源码 文件源码
def check_crossing_model(gpu):
    communicator, rank_next, rank_prev = create_communicator(gpu)

    n, d = 100, 10
    X = np.random.randn(n, d).astype(np.float32)
    Y = (np.random.rand(n) * 2).astype(np.int32)

    if communicator.rank == 0:
        model = L.Classifier(Cross0(
            d, communicator, rank_next, rank_prev))
    else:
        model = L.Classifier(Cross1(
            d, communicator, rank_next, rank_prev))

    if gpu:
        model.to_gpu()
        X = chainer.cuda.to_gpu(X)
        Y = chainer.cuda.to_gpu(Y)

    for i in range(n):
        err = model(X[i:i + 1], Y[i:i + 1])
        err.backward()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号