def hardmax(logits, name=None):
    """Returns batched one-hot vectors.

    The depth index containing the `1` is that of the maximum logit value
    along the last axis of `logits`.

    Args:
      logits: A batch tensor of logit values. The one-hot encoding is taken
        over the last dimension.
      name: Name to use when creating ops.

    Returns:
      A batched one-hot tensor of dtype `logits.dtype`. The last axis is
      reduced by `argmax` and re-expanded by `one_hot` to `depth`, so the
      result is shaped like `logits`.
    """
    with ops.name_scope(name, "Hardmax", [logits]):
        logits = ops.convert_to_tensor(logits, name="logits")
        # Under TF1 shape semantics, indexing a TensorShape yields a
        # `Dimension` whose size is read via `.value`. Under TF2 semantics it
        # yields a plain int (or None), which has no `.value` attribute.
        # Accept both forms.
        last_dim = logits.get_shape()[-1]
        depth = getattr(last_dim, "value", last_dim)
        if depth is None:
            # The static size is unknown at graph-construction time, so fall
            # back to the dynamic shape to give one_hot a valid depth.
            depth = array_ops.shape(logits)[-1]
        # NOTE(review): tie-breaking among equal maxima is whatever
        # math_ops.argmax does; TF documents it as unspecified.
        return array_ops.one_hot(
            math_ops.argmax(logits, -1), depth, dtype=logits.dtype)
attention_wrapper.py 文件源码
python
阅读 21
收藏 0
点赞 0
评论 0
评论列表
文章目录