utils.py source file

python

Project: CausalGAN  Author: mkocaoglu
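import tensorflow as tf


def distribute_input_data(data_loader, num_gpu):
    '''
    data_loader is a dictionary of tensors that are fed into the model.

    Each tensor has leading dimension num_gpu*batch_size. This function
    splits every tensor along that first axis and returns one dictionary
    per GPU, each with the same keys as data_loader and tensors of leading
    dimension batch_size. The result maps device name -> dictionary.
    If num_gpu is 0, the unsplit data is placed on '/cpu:0'.
    '''
    if num_gpu == 0:
        return {'/cpu:0': data_loader}

    # get_available_gpus() is not shown here; it is assumed to return a
    # list of the GPU device names visible to TensorFlow.
    gpus = get_available_gpus()
    if num_gpu > len(gpus):
        raise ValueError('number of gpus specified={}, more than gpus available={}'.format(num_gpu, len(gpus)))

    # Use only the first num_gpu devices.
    gpus = gpus[:num_gpu]

    data_by_gpu = {g: {} for g in gpus}
    for key, value in data_loader.items():
        # With an integer second argument, tf.split cuts axis 0 into num_gpu
        # equal pieces, so the batch size must be divisible by num_gpu.
        spl_vals = tf.split(value, num_gpu)
        for gpu, val in zip(gpus, spl_vals):
            data_by_gpu[gpu][key] = val

    return data_by_gpu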
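To make the splitting semantics concrete without TensorFlow or a GPU, here is a minimal pure-Python sketch, assuming each value is a list-like batch whose first axis is the batch dimension. The function name `split_batch_dict` and the device strings are hypothetical illustrations, not part of CausalGAN. It mimics what `tf.split(value, num_gpu)` does along axis 0, including rejecting a batch that is not divisible by the number of devices.

```python
from typing import Dict, List, Sequence


def split_batch_dict(
    data: Dict[str, Sequence], devices: List[str]
) -> Dict[str, Dict[str, list]]:
    """Hypothetical helper: split every value in `data` along its first axis,
    giving one equal, contiguous chunk to each device (like tf.split on axis 0).
    """
    n = len(devices)
    if n == 0:
        raise ValueError("need at least one device")
    out: Dict[str, Dict[str, list]] = {d: {} for d in devices}
    for key, value in data.items():
        # tf.split with an integer split count requires exact divisibility.
        if len(value) % n != 0:
            raise ValueError(
                "length of {!r} ({}) is not divisible by {} devices".format(
                    key, len(value), n
                )
            )
        chunk = len(value) // n
        for i, device in enumerate(devices):
            # Device i receives examples [i*chunk, (i+1)*chunk).
            out[device][key] = list(value[i * chunk:(i + 1) * chunk])
    return out


# Hypothetical example: a batch of 4 examples split over 2 devices.
batch = {"x": [10, 11, 12, 13], "label": [0, 1, 0, 1]}
per_device = split_batch_dict(batch, ["/gpu:0", "/gpu:1"])
print(per_device["/gpu:0"])  # {'x': [10, 11], 'label': [0, 1]}
print(per_device["/gpu:1"])  # {'x': [12, 13], 'label': [0, 1]}
```

Because the chunks are contiguous, the first batch_size examples go to the first device, matching the order in which `tf.split` returns its pieces. Unlike `distribute_input_data`, this sketch has no CPU fallback for an empty device list; it simply raises.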