def segment_sample_select(probs, segment_ids):
    """Sample one element per segment by walking each segment's running sum.

    A uniform threshold ``u`` is drawn for every segment. The elements are
    then scanned in order; each probability is added to the running total of
    its own segment, and an element is flagged when that running total
    crosses the segment's threshold (inverse-CDF sampling).

    Because the running totals are kept per segment id, the elements of a
    segment do not have to be contiguous in ``probs``.

    Args:
        probs: 1-D tensor of per-element probabilities, scanned along its
            first axis. NOTE(review): presumably float32 and non-negative,
            summing to 1 within each segment -- confirm against the caller.
        segment_ids: 1-D integer tensor, same length as ``probs``, giving the
            segment each element belongs to. The number of segments is taken
            to be ``max(segment_ids) + 1``.

    Returns:
        A boolean tensor with one entry per element of ``probs``; ``True``
        marks the element at which its segment's running total crossed the
        threshold.
    """
    num_segments = tf.reduce_max(segment_ids) + 1
    # One threshold per segment. The accumulator below starts at -u so that
    # "running total crosses u" becomes "accumulator crosses zero".
    thresholds = tf.random_uniform([num_segments])

    def scan_fn(acc, x):
        # acc[0] holds (running total - u) for every segment; acc[1] is the
        # previous step's flag and is only there to fix the output structure.
        offsets = acc[0]
        p, i = x[0], x[1]
        before = tf.gather(offsets, i)
        after = before + p
        # Add p to segment i only; every other segment is left unchanged.
        updated = offsets + tf.one_hot(i, num_segments, p)
        # Half-open test: running_before <= u < running_after. The uniform
        # draw lies in [0, 1), so u can be exactly 0; a strict "before < 0"
        # test would then never fire and the segment would get no selection.
        crossed = tf.logical_and(
            tf.less_equal(before, 0.0), tf.greater(after, 0.0))
        return updated, crossed

    _, selection = tf.scan(
        scan_fn, (probs, segment_ids), initializer=(-thresholds, False))
    return selection
评论列表
文章目录