util.py source code

python

Project: tefla  Author: openAGI
import tensorflow as tf


def one_hot(labels, num_classes, name='one_hot'):
    """Transform numeric labels into one-hot labels.

    Args:
        labels: [batch_size] target labels.
        num_classes: total number of classes.
        name: Optional name for the op scope.
    Returns:
        A [batch_size, num_classes] one-hot encoding of the labels.
    """
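    # tf.op_scope expects a list of values before the name, so a plain name scope is used.
    with tf.name_scope(name):
        batch_size = labels.get_shape()[0]
        # Column of row indices 0..batch_size-1, shape [batch_size, 1].
        indices = tf.expand_dims(tf.range(0, batch_size), 1)
        labels = tf.cast(tf.expand_dims(labels, 1), indices.dtype)
        # (row, label) pairs; TensorFlow 1.0+ takes the values first, then the axis.
        concated = tf.concat([indices, labels], 1)
        # tf.pack was renamed to tf.stack in TensorFlow 1.0.
        onehot_labels = tf.sparse_to_dense(
            concated, tf.stack([batch_size, num_classes]), 1.0, 0.0)
        onehot_labels.set_shape([batch_size, num_classes])
        return onehot_labels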
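
The function above builds (row, label) index pairs and scatters 1.0 into an all-zero [batch_size, num_classes] grid. As a rough illustration of that logic without TensorFlow, here is a minimal pure-Python sketch. The name `one_hot_reference` and the range check are assumptions added for this example and are not part of tefla.

```python
def one_hot_reference(labels, num_classes):
    """Pure-Python sketch of the index-pair scatter used by one_hot (hypothetical helper)."""
    batch_size = len(labels)
    # Pair each row index with its label, mirroring the concatenated [indices, labels] tensor.
    pairs = [(row, int(label)) for row, label in enumerate(labels)]
    # Start from an all-zero [batch_size, num_classes] grid (the default value 0.0).
    dense = [[0.0] * num_classes for _ in range(batch_size)]
    for row, col in pairs:
        # Assumption for this sketch: reject labels outside [0, num_classes).
        if not 0 <= col < num_classes:
            raise ValueError("label %d is outside [0, %d)" % (col, num_classes))
        # Scatter the "on" value 1.0 at (row, label).
        dense[row][col] = 1.0
    return dense


print(one_hot_reference([2, 0, 1], 3))
# [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
```

Later TensorFlow releases also provide a built-in `tf.one_hot(indices, depth)`, which may be preferable to a hand-rolled scatter when it is available.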