def _build_multi_gpu_model(blueprint, devices):
    """Build a data-parallel model that splits each input batch across GPUs.

    A single model is built on the CPU device. Then, for every GPU in
    ``devices``, a contiguous slice of the batch (of the first model input
    only) is fed through that same model under ``tf.device(gpu)``. The
    per-GPU outputs are concatenated back along the batch axis on the CPU,
    in GPU order, so the result lines up row-for-row with the input batch.

    Args:
        blueprint: Model description, passed unchanged to
            ``_build_single_device_model``.
        devices: Iterable of device identifiers. Only those for which
            ``is_gpu_device`` is true are used.

    Returns:
        A ``MultiGpuModel`` wrapping the CPU-built model, its inputs, and the
        merged (batch-concatenated) output tensor.

    Raises:
        ValueError: if ``devices`` contains no GPU device.
    """
    import tensorflow as tf

    model = _build_single_device_model(blueprint, cpu_device())
    gpu_devices = [d for d in devices if is_gpu_device(d)]
    gpu_count = len(gpu_devices)
    if gpu_count == 0:
        # Without this check the failure surfaces later as an opaque error
        # from ``merge`` being handed an empty list.
        raise ValueError('_build_multi_gpu_model: no GPU devices were given')

    def get_input(data, idx, parts):
        # Return the idx-th of `parts` contiguous slices along axis 0.
        # Each slice has batch // parts rows. The last slice also takes
        # the remainder, so no samples are dropped when the batch size is
        # not divisible by the number of GPUs.
        shape = tf.shape(data)
        batch = shape[:1]
        step = batch // parts
        start_row = step * idx
        if idx == parts - 1:
            rows = batch - start_row
        else:
            rows = step
        # Zero offset and full extent in every non-batch dimension.
        start = tf.concat([start_row, shape[1:] * 0], 0)
        size = tf.concat([rows, shape[1:]], 0)
        return tf.slice(data, start, size)

    # Only the first model input is split.
    # NOTE(review): a model with several inputs would be called here with a
    # single tensor; multi-input models look unsupported — confirm.
    x = model.inputs[0]
    # Lambda's output_shape excludes the batch dimension.
    input_shape = tuple(x.get_shape().as_list())[1:]

    outputs = []
    for i, device in enumerate(gpu_devices):
        with tf.device(device):
            model_input = Lambda(
                get_input,
                output_shape=input_shape,
                arguments={'idx': i, 'parts': gpu_count})(x)
            # The same model instance is called on every slice, so each
            # replica reuses the same layers (and weights).
            outputs.append(model(model_input))

    with tf.device(cpu_device()):
        if len(outputs) == 1:
            # Keras 1 `merge` rejects lists with fewer than two tensors;
            # with one GPU the single output already covers the whole batch.
            output = outputs[0]
        else:
            output = merge(outputs, mode='concat', concat_axis=0)

    return MultiGpuModel(
        model,
        model_input=model.inputs,
        model_output=output)
评论列表
文章目录