bytenet.py source code (Python)

Project: master-thesis  Author: AndreasMadsen
```python
import tensorflow as tf

# Method excerpt from the model class. `bytenet_sampling_translator` and
# `batch_beam_gather` are not part of TensorFlow; they come from the project's own code.
def sample_inference_model(self,
                           source: tf.Tensor, length: tf.Tensor,
                           samples=1,
                           reuse: bool=False) -> tf.Tensor:
    x = tf.cast(source, tf.int32)

    # beam-search decoding: one log-probability and one label sequence per beam
    logprops, labels = bytenet_sampling_translator(
        x,
        beam_size=samples,
        **self._parameters,
        name="bytenet-model",
        reuse=reuse
    )

    # check if <eos> (label id 1) exists in each sequence
    # eos_found.shape = (batch, beam)
    eos_found = tf.reduce_any(tf.equal(labels, 1), axis=2)
    # set the log-probability to something very small if <eos> was not found
    # log(epsilon) = -1e9
    log_eps = tf.constant(-1e9, dtype=logprops.dtype)
    logprops = tf.where(eos_found,
                        logprops,
                        tf.fill(tf.shape(logprops), log_eps))

    # sort the beams by logprops, best first, and reorder the labels to match
    _, indices = tf.nn.top_k(logprops, k=samples, sorted=True)
    labels = batch_beam_gather(labels, indices)

    return tf.cast(labels, source.dtype)
```
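The re-ranking after decoding is easier to follow outside tensor form. Below is a minimal pure-Python sketch of the same steps on nested lists; it is not the project's code. `rerank_beams` is a hypothetical name, `<eos>` is assumed to be label id 1 (mirroring `tf.equal(labels, 1)`), and `batch_beam_gather` is assumed to select, for each batch entry, the beams listed in `indices`.

```python
from typing import List, Sequence

EOS_ID = 1        # assumption: <eos> is label id 1, as in tf.equal(labels, 1)
LOG_EPS = -1e9    # stand-in for log(epsilon), as in the TensorFlow method


def rerank_beams(logprobs: Sequence[Sequence[float]],
                 labels: Sequence[Sequence[Sequence[int]]],
                 k: int) -> List[List[List[int]]]:
    """Hypothetical sketch of the post-processing in sample_inference_model.

    logprobs: (batch, beam) log-probability of each beam.
    labels:   (batch, beam, time) decoded label ids of each beam.
    Returns (batch, k) label sequences, best beam first.
    """
    reranked = []
    for beam_logprobs, beam_labels in zip(logprobs, labels):
        # tf.where step: beams that never emit <eos> get log(epsilon)
        masked = [lp if EOS_ID in seq else LOG_EPS
                  for lp, seq in zip(beam_logprobs, beam_labels)]
        # tf.nn.top_k step: indices of the k largest scores, descending;
        # sorted() is stable, so ties keep the lower beam index first
        top = sorted(range(len(masked)), key=lambda i: masked[i],
                     reverse=True)[:k]
        # assumed batch_beam_gather behaviour: pick those beams per batch entry
        reranked.append([list(beam_labels[i]) for i in top])
    return reranked


if __name__ == "__main__":
    # one batch entry, three beams; beam 1 has the best raw score but no <eos>
    print(rerank_beams([[-1.0, -0.5, -2.0]],
                       [[[5, 1, 0], [4, 4, 4], [3, 1, 0]]], k=3))
    # → [[[5, 1, 0], [3, 1, 0], [4, 4, 4]]]
```

In the example, the middle beam has the highest raw log-probability (-0.5) but never emits `<eos>`, so it drops to last place. Because the method passes `samples` both as `beam_size` and as `k` to `tf.nn.top_k`, it re-sorts the whole beam rather than truncating it.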