beamsearch.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:TextGAN 作者: ankitkv 项目源码 文件源码
def unwrap_output_sparse(self, final_state, include_stop_tokens=True):
        """
        Retreive the beam search output from the final state.

        Returns a sparse tensor with underlying dimensions of [batch_size, max_len]
        """
        output_dense = final_state[0]
        mask = tf.not_equal(output_dense, self.stop_token)

        if include_stop_tokens:
            output_dense = tf.concat(1, [output_dense[:, 1:],
                                         tf.ones_like(output_dense[:, 0:1]) *
                                         self.stop_token])
            mask = tf.concat(1, [mask[:, 1:], tf.cast(tf.ones_like(mask[:, 0:1],
                                                                   dtype=tf.int8),
                                                      tf.bool)])

        return sparse_boolean_mask(output_dense, mask)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号