import tensorflow as tf


def distribute_input_data(data_loader, num_gpu):
    '''
    data_loader is a dictionary of tensors that are fed into our model.
    Each tensor has a leading dimension of num_gpu * batch_size. This function
    breaks that dictionary up into num_gpu dictionaries with the same keys,
    whose tensors have a leading dimension of batch_size, and assigns one
    dictionary to each GPU device.
    '''
    # With no GPUs requested, keep all of the input on the CPU.
    if num_gpu == 0:
        return {'/cpu:0': data_loader}

    # get_available_gpus() returns the device names of the GPUs visible to
    # TensorFlow (see the sketch after this function).
    gpus = get_available_gpus()
    if num_gpu > len(gpus):
        raise ValueError(
            'number of gpus specified={}, more than gpus available={}'.format(
                num_gpu, len(gpus)))
    gpus = gpus[:num_gpu]

    # Split every tensor along its first dimension and hand one slice to each GPU.
    data_by_gpu = {g: {} for g in gpus}
    for key, value in data_loader.items():
        spl_vals = tf.split(value, num_gpu)
        for gpu, val in zip(gpus, spl_vals):
            data_by_gpu[gpu][key] = val
    return data_by_gpu
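
# distribute_input_data() relies on get_available_gpus(), which is not defined in
# this snippet. Below is a minimal sketch of that helper, assuming the common
# TF 1.x pattern of querying device_lib.list_local_devices(); the original
# author's implementation may differ.
from tensorflow.python.client import device_lib


def get_available_gpus():
    '''Return the device names of every GPU visible to TensorFlow.'''
    local_device_protos = device_lib.list_local_devices()
    return [x.name for x in local_device_protos if x.device_type == 'GPU']


# Hypothetical usage: build one model replica ("tower") per returned device.
# data = {'images': images, 'labels': labels}   # leading dim = num_gpu * batch_size
# for device_name, feed in distribute_input_data(data, num_gpu=2).items():
#     with tf.device(device_name):
#         build_tower(feed)                      # build_tower is a placeholder name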