layers.py source code (Python)

Project: LiTeFlow  Author: petrux
```python
import tensorflow as tf
from liteflow import utils  # assumed import path for LiTeFlow's `utils` helpers


def finished(self, time, output):
    """Check which sentences are finished.

    Arguments:
      time: a `Tensor` of rank `0D` (i.e. a scalar) with the 0-based value of the
        current step in the loop.
      output: a `Tensor` of rank `2D` and shape `[batch_size, num_classes]` representing
        the current output of the model, i.e. a batch of probability distribution
        estimates over the output classes.

    Returns:
      a `Tensor` of shape `[batch_size]` of `tf.bool` elements, indicating for each
      position whether the corresponding sequence has terminated. A sequence has
      terminated if the number of steps taken so far (`time + 1`) is greater than or
      equal to the number of steps allowed (defined in the `lengths` input argument),
      or if the `argmax` over the output probability distribution is the class whose
      id equals the `EOS` symbol (if provided).
    """
    # Number of steps taken so far, including the current one.
    length = time + 1
    # `self._lengths` may be a scalar or a `[batch_size]` tensor.
    finished = tf.greater_equal(length, self._lengths)
    if finished.get_shape().ndims == 0:
        # A scalar limit yields a single flag: replicate it once per batch element.
        batch = [utils.get_dimension(output, 0)]
        finished = tf.tile([finished], batch)
    if self._EOS is not None:
        # Predicted class id for each batch element.
        ids = tf.cast(tf.argmax(output, axis=-1), tf.int32)
        eos = tf.equal(ids, self._EOS)
        # A sequence also stops as soon as its most likely class is EOS.
        finished = tf.logical_or(finished, eos)
    return finished
```
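To make the termination rule concrete, here is a minimal NumPy sketch of the same check outside TensorFlow. It is not part of LiTeFlow: the function name `finished_mask` and the parameters `lengths` and `eos_id` are hypothetical stand-ins for `self._lengths` and `self._EOS`, and the sketch assumes `output` is any array-like of shape `[batch_size, num_classes]`.

```python
import numpy as np


def finished_mask(time, output, lengths, eos_id=None):
    """Hypothetical NumPy sketch of the termination check (not LiTeFlow code).

    time:    0-based index of the current decoding step (int).
    output:  array-like of shape [batch_size, num_classes].
    lengths: int, or array-like of shape [batch_size], with the allowed step counts.
    eos_id:  optional class id of the EOS symbol.
    Returns a bool array of shape [batch_size].
    """
    output = np.asarray(output)
    batch_size = output.shape[0]
    length = time + 1  # steps taken so far, including the current one
    # A scalar `lengths` gives a single flag; broadcast it to one flag per element,
    # mirroring the `tf.tile` call in the original method.
    finished = np.broadcast_to(np.greater_equal(length, lengths), (batch_size,))
    if eos_id is not None:
        ids = np.argmax(output, axis=-1)  # predicted class per batch element
        finished = np.logical_or(finished, ids == eos_id)
    return finished


# Usage: three sequences, two classes, class 0 playing the role of EOS.
probs = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
print(finished_mask(0, probs, [1, 3, 3], eos_id=0))
```

The first sequence stops because it has used up its single allowed step, the second because its most likely class is EOS, and the third keeps going; the two conditions are combined with a logical OR, just as `tf.logical_or` does above.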