base.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:tfutils 作者: neuroailab 项目源码 文件源码
def split_input(inputs, num_gpus=1):
    if not isinstance(num_gpus, list):
        n_gpus = num_gpus
    else:
        n_gpus = len(num_gpus)

    if n_gpus == 1:
        return [inputs]

    temp_args = {v: tf.split(inputs[v], axis=0, num_or_size_splits=num_gpus)
                 for v in inputs}

    list_of_args = [{now_arg: temp_args[now_arg][ind]
                     for now_arg in temp_args} for ind in xrange(n_gpus)]

    return list_of_args
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号