def finished(self, time, output):
    """Check which sequences in the batch are finished.

    Args:
        time: a `Tensor` of rank 0 (i.e. a scalar) holding the 0-based index of
            the current step in the loop.
        output: a `Tensor` of rank 2 and shape `[batch_size, num_classes]`
            holding the current output of the model, i.e. a batch of
            probability distribution estimates over the output classes.

    Returns:
        a `Tensor` of shape `[batch_size]` of `tf.bool` elements indicating, for
        each position, whether the corresponding sequence has terminated. A
        sequence has terminated if EITHER of the following holds:
          * the number of steps taken so far (`time + 1`) is greater than or
            equal to its allowed number of steps (`self._lengths`, presumably
            set from the `lengths` input argument -- verify against the
            constructor);
          * an `EOS` id is configured (`self._EOS is not None`) and the
            `argmax` of the sequence's output distribution is that id.
    """
    # `time` is 0-based, so once this step is done `time + 1` steps have run.
    steps_taken = time + 1
    done = tf.greater_equal(steps_taken, self._lengths)

    # A scalar `self._lengths` yields a scalar comparison result; replicate it
    # so there is one flag per batch entry (batch size read from `output`).
    if done.get_shape().ndims == 0:
        batch = [utils.get_dimension(output, 0)]
        done = tf.tile([done], batch)

    if self._EOS is not None:
        # Predicted class = argmax over the last (class) axis of `output`.
        ids = tf.cast(tf.argmax(output, axis=-1), tf.int32)
        # Either condition terminates the sequence (logical OR, not AND).
        done = tf.logical_or(done, tf.equal(ids, self._EOS))
    return done
评论列表
文章目录