def one_hot_matrix(tensor_in, num_classes, on_value=1.0, off_value=0.0):
    """Expand a tensor of class indices into a one-hot encoded tensor.

    The input is cast to int64 and handed to the `one_hot` op, which appends
    a new dimension of size `num_classes`.

    TODO(ilblackdragon): Ideally implementation should be
    part of TensorFlow with Eigen-native operation.

    Args:
      tensor_in: Input tensor of indices, of shape [N1, N2].
      num_classes: Number of classes to expand each index into.
      on_value: Tensor or float, value written at the position given by
        each index.
      off_value: Tensor or float, value written at every other position.

    Returns:
      Tensor of shape [N1, N2, num_classes] holding `on_value` at the
      position of each index from the original tensor and `off_value`
      elsewhere.
    """
    # The one-hot op is fed integer indices, so normalise the dtype first.
    indices = math_ops.cast(tensor_in, dtypes.int64)
    return array_ops_.one_hot(indices, num_classes, on_value, off_value)
# Web-page residue from the source listing: "Comment list" / "Table of contents".