def _get_sharding_func(size, num_shards):
"""Create sharding function for scatter update."""
def func(ids):
if num_shards == 1:
return None, ids
else:
ids_per_shard = size // num_shards
extras = size % num_shards
assignments = tf.maximum(ids // (ids_per_shard + 1),
(ids - extras) // ids_per_shard)
new_ids = tf.select(assignments < extras,
ids % (ids_per_shard + 1),
(ids - extras) % ids_per_shard)
return assignments, new_ids
return func
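# Comment list (page-navigation text left over from the source web page)
# Table of contents (page-navigation text left over from the source web page)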