segment.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:jack 作者: uclmr 项目源码 文件源码
def segment_sample_select(probs, segment_ids):
    num_segments = tf.reduce_max(segment_ids) + 1
    sampled = tf.random_uniform([num_segments])

    def scan_fn(acc, x):
        p, i = x[0], x[1]
        prev_v = tf.gather(acc[0], i)
        new_probs = acc[0] + tf.one_hot(i, num_segments, p)
        select = tf.logical_and(tf.less(prev_v, 0.0), tf.greater_equal(prev_v + p, 0.0))
        return new_probs, select

    _, selection = tf.scan(scan_fn, (probs, segment_ids), initializer=(-sampled, False))

    return selection
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号