mrt_utils.py source file

python

Project: THUMT  Author: thumt
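import tensorflow as tf  # written against the TensorFlow 1.x API


def sampler(symbols_to_logits_fn, initial_ids, sample_num, decode_length,
            vocab_size, eos_id, features=None):
    """Draw sample_num random continuations of length decode_length per input.

    vocab_size, eos_id and features are accepted but not used here; sampling
    always runs for decode_length steps and does not stop at eos_id.
    """
    # Expand each batch entry to sample_num copies.
    seqlen = tf.constant(0)
    alive_seq = tf.tile(tf.expand_dims(initial_ids, 1), [1, sample_num])
    alive_seq = tf.expand_dims(alive_seq, 2)  # (batch_size, sample_num, 1)
    sa = tf.shape(alive_seq)
    # Flatten to (batch_size * sample_num, 1) so every row is one sample.
    alive_seq = tf.reshape(alive_seq, [sa[0] * sa[1], 1])

    def _not_finished(i, alive_seq):
        # Loop condition: keep going until decode_length symbols are sampled.
        return i < decode_length

    def inner_loop(i, alive_seq):
        # Logits for the next symbol of every row.
        logit = symbols_to_logits_fn(alive_seq)[0]
        # Sample one symbol id per row from the categorical distribution.
        new_samples = tf.multinomial(logit, 1)
        new_samples = tf.to_int32(new_samples)
        alive_seq = tf.concat([alive_seq, new_samples], 1)
        return (i + 1, alive_seq)

    (_, alive_seq) = tf.while_loop(
        _not_finished,
        inner_loop,
        [seqlen, alive_seq],
        shape_invariants=[
            tf.TensorShape([]),
            # The sequence axis grows by one every iteration.
            tf.TensorShape([None, None])
        ],
        parallel_iterations=1,
        back_prop=False
    )
    # The actual row count is batch_size * sample_num, so this static shape
    # is only accurate when batch_size is 1.
    alive_seq.set_shape((sample_num, None))

    return alive_seq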
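
To make the shapes and the sampling step of the sampler above easier to follow, here is a minimal pure-Python sketch of the same idea, without TensorFlow. It is an illustration, not THUMT code. The names `sample_sequences`, `toy_logits` and the `seed` parameter are hypothetical, and `logits_fn` is a simplified stand-in for `symbols_to_logits_fn` that scores one prefix at a time instead of a whole batch.

```python
import math
import random


def sample_sequences(logits_fn, initial_ids, sample_num, decode_length, seed=0):
    """Pure-Python analogue of the sampler (illustrative only).

    logits_fn (hypothetical) takes one partial sequence as a list of ints and
    returns a list of unnormalized scores, one per vocabulary entry.
    """
    rng = random.Random(seed)
    # Repeat every start symbol sample_num times, giving
    # len(initial_ids) * sample_num rows; copies of one input are adjacent.
    rows = [[start] for start in initial_ids for _ in range(sample_num)]
    for _ in range(decode_length):
        for row in rows:
            logits = logits_fn(row)
            # Softmax weights, shifted by the max for numerical stability.
            peak = max(logits)
            weights = [math.exp(x - peak) for x in logits]
            # Draw one token id in proportion to its softmax probability.
            next_id = rng.choices(range(len(logits)), weights=weights, k=1)[0]
            row.append(next_id)
    return rows


def toy_logits(prefix):
    # Hypothetical 4-word vocabulary that favors (last token + 1) % 4.
    favored = (prefix[-1] + 1) % 4
    return [5.0 if i == favored else 0.0 for i in range(4)]


samples = sample_sequences(toy_logits, initial_ids=[0, 2],
                           sample_num=3, decode_length=5)
```

With two start symbols and `sample_num=3`, `samples` holds six rows, each containing the start symbol followed by five sampled ids. This mirrors the `(batch_size * sample_num, 1 + decode_length)` layout of `alive_seq` once the loop finishes.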