def _get_sharding_func(size, num_shards):
"""Create sharding function for scatter update."""
def func(ids):
if num_shards == 1:
return None, ids
else:
ids_per_shard = size // num_shards
extras = size % num_shards
assignments = math_ops.maximum(ids // (ids_per_shard + 1),
(ids - extras) // ids_per_shard)
new_ids = array_ops.where(assignments < extras,
ids % (ids_per_shard + 1),
(ids - extras) % ids_per_shard)
return assignments, new_ids
return func
factorization_ops.py 文件源码
python
阅读 26
收藏 0
点赞 0
评论 0
评论列表
文章目录