def scatter_update(cls, factor, indices, values, sharding_func):
    """Build an op that scatter-updates a (possibly sharded) factor.

    Args:
        cls: Unused in the body; presumably present because this is bound as
            a classmethod (no decorator is visible here -- confirm).
        factor: List of shard variables. A one-element list is handled as an
            unsharded factor.
        indices: Indices of the entries to overwrite. Passed directly to
            `tf.scatter_update` for a single shard, otherwise passed to
            `sharding_func`.
        values: Replacement values for `indices`. In the sharded case they
            are partitioned with the same assignments as the indices.
        sharding_func: Callable taking `indices` and returning an
            `(assignments, new_ids)` pair. `assignments` selects the shard
            for each index and `new_ids` is partitioned into the per-shard
            indices. Only called when `factor` has more than one shard.

    Returns:
        For a single shard, the `.op` of the scatter update. Otherwise a
        `tf.group` op over one scatter update per shard.

    Raises:
        AssertionError: If `factor` is not a list, or if `sharding_func`
            returns `None` as the assignments.
    """
    assert isinstance(factor, list)
    if len(factor) == 1:
        with ops.colocate_with(factor[0]):
            # TODO(agarwal): assign instead of scatter update for full batch update.
            return tf.scatter_update(factor[0], indices, values).op

    num_shards = len(factor)
    assignments, new_ids = sharding_func(indices)
    assert assignments is not None
    # The partition ids are cast to int32 before being used to split both
    # the ids and the values, so the two partitions stay aligned per shard.
    assignments = tf.cast(assignments, tf.int32)
    sharded_ids = tf.dynamic_partition(new_ids, assignments, num_shards)
    sharded_values = tf.dynamic_partition(values, assignments, num_shards)

    updates = []
    # zip() instead of an xrange() index loop: no Python-2-only builtin and
    # no manual indexing into three parallel lists.
    for shard, shard_ids, shard_values in zip(factor, sharded_ids, sharded_values):
        # Colocate each update with the shard it writes, matching the
        # single-shard branch above.
        with ops.colocate_with(shard):
            updates.append(tf.scatter_update(shard, shard_ids, shard_values))
    return tf.group(*updates)
评论列表
文章目录