def split_input(inputs, num_gpus=1):
    """Split every value in ``inputs`` along axis 0 into one shard per GPU.

    Args:
        inputs: Mapping whose values are passed to ``tf.split`` along axis 0
            (presumably batched tensors -- confirm against callers).
        num_gpus: Either an int giving the number of shards, or a list/tuple
            whose length gives the number of shards. A list/tuple is passed
            unchanged to ``tf.split`` as ``num_or_size_splits``, so TensorFlow
            interprets its elements as the per-shard sizes along axis 0.

    Returns:
        A list with one dict per shard; the i-th dict maps every key of
        ``inputs`` to the i-th piece of that key's value. When there is only
        one shard, ``[inputs]`` is returned unchanged and ``tf.split`` is not
        called.
    """
    # Tuples are accepted like lists; previously a tuple fell into the int
    # branch and crashed when used as a loop bound.
    if isinstance(num_gpus, (list, tuple)):
        n_gpus = len(num_gpus)
    else:
        n_gpus = num_gpus
    if n_gpus == 1:
        # Nothing to shard: the original mapping is the single shard.
        return [inputs]
    # NOTE(review): a list/tuple `num_gpus` is forwarded as-is, so tf.split
    # treats its elements as shard sizes. If callers actually pass GPU ids
    # here, this argument should be `n_gpus` instead -- confirm with callers.
    split_by_key = {
        key: tf.split(inputs[key], axis=0, num_or_size_splits=num_gpus)
        for key in inputs
    }
    # `range` replaces the Python-2-only `xrange` (NameError on Python 3).
    return [
        {key: shards[ind] for key, shards in split_by_key.items()}
        for ind in range(n_gpus)
    ]
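# NOTE(review): the original lines here were web-page navigation residue
# ("comment list", "table of contents") picked up when this code was copied
# from a blog page; they are not part of the program.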