factorization_ops.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def _get_sharding_func(size, num_shards):
    """Create sharding function for scatter update."""
    def func(ids):
      if num_shards == 1:
        return None, ids
      else:
        ids_per_shard = size // num_shards
        extras = size % num_shards
        assignments = tf.maximum(ids // (ids_per_shard + 1),
                                 (ids - extras) // ids_per_shard)
        new_ids = tf.select(assignments < extras,
                            ids % (ids_per_shard + 1),
                            (ids - extras) % ids_per_shard)
        return assignments, new_ids
    return func
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号