def add_sync_queues_and_barrier(self, name_prefix, enqueue_after_list):
    """Adds ops that synchronize workers through per-worker shared queues.

    One FIFO queue is created per worker. The returned op enqueues a token
    on the queue of every other worker and dequeues ``num_workers - 1``
    tokens from this worker's own queue.

    Args:
      name_prefix: prefix for the shared_name of the queues; the queue for
        worker ``i`` gets the shared_name ``name_prefix`` followed by ``i``.
      enqueue_after_list: ops used as control dependencies for the enqueue
        ops.

    Returns:
      An op that should be used as control dependency before starting the
      next step.
    """
    self.sync_queue_counter += 1
    # The counter advances on every call, so successive calls cycle through
    # the entries of self.sync_queue_devices.
    device_index = self.sync_queue_counter % len(self.sync_queue_devices)
    with tf.device(self.sync_queue_devices[device_index]):
        sync_queues = [
            tf.FIFOQueue(self.num_workers, [tf.bool], shapes=[[]],
                         shared_name='%s%s' % (name_prefix, i))
            for i in range(self.num_workers)]
        queue_ops = []
        # For each other worker, add an entry in a queue, signaling that it
        # can finish this step.
        token = tf.constant(False)
        with tf.control_dependencies(enqueue_after_list):
            for i, q in enumerate(sync_queues):
                if i == self.task_index:
                    # Nothing is enqueued on this worker's own queue; the
                    # no_op is still created under the control dependencies
                    # on enqueue_after_list.
                    queue_ops.append(tf.no_op())
                else:
                    queue_ops.append(q.enqueue(token))

        # Drain tokens off queue for this worker, one for each other worker.
        queue_ops.append(
            sync_queues[self.task_index].dequeue_many(len(sync_queues) - 1))

        return tf.group(*queue_ops)
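# Web-page residue (not part of the code): "Comment list"
# Web-page residue (not part of the code): "Table of contents"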