factorization_ops.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码
def _get_sharding_func(size, num_shards):
    """Create sharding function for scatter update."""

    def func(ids):
      if num_shards == 1:
        return None, ids
      else:
        ids_per_shard = size // num_shards
        extras = size % num_shards
        assignments = math_ops.maximum(ids // (ids_per_shard + 1),
                                       (ids - extras) // ids_per_shard)
        new_ids = array_ops.where(assignments < extras,
                                  ids % (ids_per_shard + 1),
                                  (ids - extras) % ids_per_shard)
        return assignments, new_ids

    return func
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号