build.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:minos 作者: guybedo 项目源码 文件源码
def _build_multi_gpu_model(blueprint, devices):
    import tensorflow as tf
    model = _build_single_device_model(blueprint, cpu_device())
    gpu_devices = [d for d in devices if is_gpu_device(d)]
    gpu_count = len(gpu_devices)

    def get_input(data, idx, parts):
        shape = tf.shape(data)
        size = tf.concat([ shape[:1] // parts, shape[1:] ], 0)
        stride = tf.concat([ shape[:1] // parts, shape[1:]*0 ], 0)
        start = stride * idx
        return tf.slice(data, start, size)

    outputs = []
    for i, device in enumerate(gpu_devices):
        with tf.device(device):
            x = model.inputs[0]
            input_shape = tuple(x.get_shape().as_list())[1:]
            model_input = Lambda(
                get_input, 
                output_shape=input_shape, 
                arguments={'idx':i,'parts':gpu_count})(x)
            outputs.append(model(model_input))
    with tf.device(cpu_device()):
        output = merge(outputs, mode='concat', concat_axis=0)
        return MultiGpuModel(
            model, 
            model_input=model.inputs, 
            model_output=output)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号