densenet_multi_gpu.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:cifar-10-cnn 作者: BIGBALLON 项目源码 文件源码
def to_multi_gpu(model, n_gpus=2):
    if n_gpus ==1:
        return model

    with tf.device('/cpu:0'):
        x = Input(model.input_shape[1:])
    towers = []
    for g in range(n_gpus):
        with tf.device('/gpu:' + str(g)):
            slice_g = Lambda(slice_batch, lambda shape: shape, arguments={'n_gpus':n_gpus, 'part':g})(x)
            towers.append(model(slice_g))

    with tf.device('/cpu:0'):
        merged = Concatenate(axis=0)(towers)
    return Model(inputs=[x], outputs=merged)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号