def to_multi_gpu(model, n_gpus=2):
    """Wrap ``model`` so each input batch is split across ``n_gpus`` GPUs.

    A new input is created on the CPU, one slice of it is fed to ``model``
    on each of ``/gpu:0`` .. ``/gpu:<n_gpus - 1>``, and the per-device
    outputs are concatenated along axis 0 on the CPU.

    Args:
        model: Model to replicate. The same ``model`` object is called on
            every slice. Its ``input_shape[1:]`` is used as the shape of
            the new input.
        n_gpus: Number of GPU devices to spread the batch over.

    Returns:
        ``model`` itself when ``n_gpus == 1``; otherwise a new ``Model``
        mapping the CPU-side input to the concatenated tower outputs.

    Raises:
        ValueError: If ``n_gpus`` is smaller than 1.
    """
    if n_gpus < 1:
        # Without this guard no towers are built and the failure only
        # shows up later, inside the Concatenate call.
        raise ValueError('n_gpus must be >= 1, got %r' % (n_gpus,))
    if n_gpus == 1:
        # Nothing to parallelise; hand the model back untouched.
        return model

    # Shared entry point for all towers, placed on the CPU.
    with tf.device('/cpu:0'):
        x = Input(model.input_shape[1:])

    towers = []
    for g in range(n_gpus):
        with tf.device('/gpu:' + str(g)):
            # slice_batch is defined outside this block; presumably it
            # returns the g-th of n_gpus chunks of the batch -- TODO confirm.
            # The output-shape function is the identity, i.e. the slice is
            # declared to have the same shape as the full input.
            slice_g = Lambda(
                slice_batch,
                output_shape=lambda shape: shape,
                arguments={'n_gpus': n_gpus, 'part': g},
            )(x)
            towers.append(model(slice_g))

    # Stitch the per-device outputs back together along axis 0 on the CPU.
    with tf.device('/cpu:0'):
        merged = Concatenate(axis=0)(towers)

    return Model(inputs=[x], outputs=merged)
评论列表
文章目录