def sample(self, time, outputs, state, name=None):
    """Greedy sampling step: pick the arg-max id along the last axis.

    Args:
        time: Ignored by this implementation.
        outputs: Must be a single ``ops.Tensor``. The original comment
            describes it as logits; the index of the largest value along
            the last axis is selected.
        state: Ignored by this implementation.
        name: Not used in this method body.

    Returns:
        The result of ``argmax(outputs, axis=-1)`` cast to ``dtypes.int32``.

    Raises:
        TypeError: If ``outputs`` is not an ``ops.Tensor`` instance.
    """
    # Neither the step counter nor the cell state influences a greedy pick.
    del time, state

    if isinstance(outputs, ops.Tensor):
        # Most probable id = position of the maximum along the final axis.
        best_ids = math_ops.argmax(outputs, axis=-1)
        return math_ops.cast(best_ids, dtypes.int32)

    raise TypeError("Expected outputs to be a single Tensor, got: %s" %
                    type(outputs))
评论列表
文章目录