AI_multi_GPU_rollout_v3.py source code

Project: Renju-AI, author: yao62995
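
```python
import tensorflow as tf  # written against the pre-1.0 TensorFlow API


def one_hot_encoding(labels, num_classes, scope=None):
    """Transform numeric labels into one-hot labels.

    Args:
      labels: [batch_size] target labels.
      num_classes: total number of classes.
      scope: Optional scope for op_scope.
    Returns:
      A [batch_size, num_classes] float tensor holding the one-hot encoding
      of the labels.
    """
    # tf.op_scope, tf.pack and the tf.concat(concat_dim, values) argument
    # order belong to the pre-1.0 API; TensorFlow 1.0 uses
    # tf.name_scope(name, default_name, values), tf.stack and
    # tf.concat(values, axis) instead.
    with tf.op_scope([labels], scope, 'OneHotEncoding'):
        # The batch size must be known statically for tf.range below.
        batch_size = labels.get_shape()[0]
        # Row indices 0..batch_size-1 as a [batch_size, 1] column.
        indices = tf.expand_dims(tf.range(0, batch_size), 1)
        # Labels as a [batch_size, 1] column with the same dtype as indices.
        labels = tf.cast(tf.expand_dims(labels, 1), indices.dtype)
        # [batch_size, 2] matrix of (row, label) coordinates.
        concated = tf.concat(1, [indices, labels])
        # Write 1.0 at each coordinate and 0.0 everywhere else.
        onehot_labels = tf.sparse_to_dense(
            concated, tf.pack([batch_size, num_classes]), 1.0, 0.0)
        onehot_labels.set_shape([batch_size, num_classes])
        return onehot_labels
```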
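
The function pairs every row index with its label to get a [batch_size, 2] coordinate matrix, then scatters 1.0 at those coordinates into an otherwise all-zeros [batch_size, num_classes] tensor. The plain-Python sketch below mirrors that idea without depending on the old TensorFlow API; the name `one_hot_py` and the explicit range check are assumptions of this sketch, not part of the original file.

```python
from typing import List, Sequence, Tuple


def one_hot_py(labels: Sequence[int], num_classes: int) -> List[List[float]]:
    """Hypothetical plain-Python version of the index-pair + scatter approach."""
    batch_size = len(labels)
    # Pair each row index with its label, like tf.concat([indices, labels]).
    coords: List[Tuple[int, int]] = list(zip(range(batch_size), labels))
    # Start from a dense matrix filled with the default value 0.0.
    dense = [[0.0] * num_classes for _ in range(batch_size)]
    # Scatter 1.0 at every (row, label) coordinate, like tf.sparse_to_dense.
    for row, label in coords:
        # Range check added in this sketch (assumption, not in the original).
        if not 0 <= label < num_classes:
            raise ValueError(f"label {label} out of range [0, {num_classes})")
        dense[row][label] = 1.0
    return dense


print(one_hot_py([2, 0, 1], 3))
# [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
```

For integer labels on TensorFlow 1.0 and later, the built-in `tf.one_hot(labels, num_classes)` should give the same float32 result without building the coordinate matrix by hand.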