def _conditional_batch(tensors, accept_prob, batch_size, queue_threads=10):
    """Batch `tensors`, keeping each example with probability `accept_prob`.

    Builds a FIFO queue whose enqueue op is guarded by a coin flip. Each time
    the guarded op runs, a scalar is drawn with `random_uniform([])`, and the
    tensors are enqueued only when that draw is less than `accept_prob`.
    Otherwise a no-op runs instead. A QueueRunner holding `queue_threads`
    copies of the guarded op is registered, and a `dequeue_many(batch_size)`
    result is returned.

    Args:
        tensors: List of tensors to enqueue. Every tensor must have a fully
            defined static shape.
        accept_prob: Rank-0 tensor holding the per-example acceptance
            probability.
        batch_size: Number of examples per dequeued batch; also used as the
            queue capacity.
        queue_threads: Number of copies of the conditional enqueue op handed
            to the QueueRunner.

    Returns:
        A list of batched tensors. A list is returned even when `tensors`
        holds a single element.

    Raises:
        ValueError: If `accept_prob` is not rank 0, or if any tensor in
            `tensors` does not have a fully defined shape.
    """
    # The acceptance probability must be a scalar.
    accept_prob.get_shape().assert_has_rank(0)

    # The queue needs the exact static shape and dtype of every component.
    shapes_list, dtypes_list = [], []
    for component in tensors:
        static_shape = component.get_shape()
        static_shape.assert_is_fully_defined()
        shapes_list.append(static_shape)
        dtypes_list.append(component.dtype)

    final_q = data_flow_ops.FIFOQueue(
        capacity=batch_size,
        shapes=shapes_list,
        dtypes=dtypes_list,
        name='batched_queue')
    summary_tag = 'queue/%s/size' % final_q.name
    logging_ops.scalar_summary(summary_tag, final_q.size())

    # Fresh uniform draw per run of the op; accept when draw < accept_prob.
    coin_flip = random_ops.random_uniform([])
    should_enqueue = math_ops.less(coin_flip, accept_prob)

    def _enqueue_all():
        return final_q.enqueue(tensors)

    # True branch enqueues the example; the false branch does nothing.
    conditional_enqueue = control_flow_ops.cond(
        should_enqueue, _enqueue_all, control_flow_ops.no_op)

    enqueue_ops = [conditional_enqueue] * queue_threads
    runner = queue_runner.QueueRunner(final_q, enqueue_ops)
    queue_runner.add_queue_runner(runner)

    batched = final_q.dequeue_many(batch_size)
    # dequeue_many yields a bare Tensor when the queue has a single
    # component; wrap it so callers always receive a list.
    return [batched] if isinstance(batched, ops.Tensor) else batched
评论列表
文章目录