def decode_sparse(self, include_stop_tokens=True):
    """Decode densely, then drop masked-out positions via ``sparse_boolean_mask``.

    The dense symbols and log-probabilities come from ``self.decode_dense()``.
    A boolean mask is built by comparing every symbol against
    ``self.stop_token``; the symbols and the mask are then handed to
    ``sparse_boolean_mask``.

    Args:
        include_stop_tokens: When true (the default), the mask is shifted one
            step to the right along axis 1 and its first column is forced to
            true. A position is then kept iff it is the first one or the
            symbol just before it is not the stop token, so a stop token that
            directly follows a non-stop symbol is retained. When false, every
            position holding the stop token is masked out.

    Returns:
        A pair ``(sparse_symbols, logprobs)``: the result of
        ``sparse_boolean_mask`` on the dense symbols, and the log-probabilities
        exactly as returned by ``decode_dense``.
    """
    # Assumes the dense symbols are laid out as (batch, time) -- TODO confirm
    # against decode_dense; the slicing below only requires rank >= 2.
    symbols, logprobs = self.decode_dense()
    keep = tf.not_equal(symbols, self.stop_token)

    if include_stop_tokens:
        # Shift the mask right by one step: position j now reflects whether
        # the symbol at j - 1 was a non-stop token, and column 0 is always kept.
        leading_true = tf.ones_like(keep[:, :1])
        shifted = keep[:, :-1]
        # NOTE(review): dimension-first argument order is the pre-1.0
        # TensorFlow signature of tf.concat -- confirm the TF version in use.
        keep = tf.concat(1, [leading_true, shifted])

    return sparse_boolean_mask(symbols, keep), logprobs
评论列表
文章目录