def unwrap_output_sparse(self, final_state, include_stop_tokens=True):
    """
    Retrieve the beam search output from the final state as a sparse tensor.

    Args:
        final_state: Final beam search state. Element 0 is taken as the dense
            output token tensor and is indexed as [row, position] below.
        include_stop_tokens: If False, every position holding
            ``self.stop_token`` is masked out of the dense output as-is.
            If True, the first column of the dense output is dropped, a
            column filled with ``self.stop_token`` is appended at the end,
            and that appended column is always kept (stop tokens in every
            other column are still masked out).

    Returns:
        ``sparse_boolean_mask`` applied to the (possibly shifted) tokens and
        mask: a sparse tensor with underlying dimensions of
        [batch_size, max_len].
    """
    tokens = final_state[0]
    # True wherever the emitted symbol is not the stop token.
    keep = tf.not_equal(tokens, self.stop_token)

    if not include_stop_tokens:
        return sparse_boolean_mask(tokens, keep)

    # NOTE(review): column 0 is presumably the initial/GO symbol fed to the
    # decoder, which is why it is dropped here -- confirm against the caller.
    # tf.concat uses the axis-first argument order of pre-1.0 TensorFlow.
    stop_column = tf.ones_like(tokens[:, 0:1]) * self.stop_token
    shifted_tokens = tf.concat(1, [tokens[:, 1:], stop_column])

    # Force the appended stop-token column to be kept. The all-True column is
    # built as int8 ones and cast to bool; presumably because older TF could
    # not build a bool "ones" tensor directly -- kept as-is for compatibility.
    always_keep = tf.cast(tf.ones_like(keep[:, 0:1], dtype=tf.int8), tf.bool)
    shifted_keep = tf.concat(1, [keep[:, 1:], always_keep])

    return sparse_boolean_mask(shifted_tokens, shifted_keep)
评论列表
文章目录