gan_metrics.py source code

python

Project: tefla  Author: openAGI
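import tensorflow as tf


def _kl_divergence(p, p_logits, q):
    """Per-row KL(p || q) for p of shape [N, K] and q of shape [K].

    Not shown in the original excerpt; a minimal version consistent with how
    it is called in `classifier_score` below.
    """
    p_logits.shape.assert_has_rank(2)
    q.shape.assert_has_rank(1)
    # log_softmax(p_logits) is a numerically safer way to get log(p).
    return tf.reduce_sum(
        p * (tf.nn.log_softmax(p_logits) - tf.math.log(q)), axis=1)


def classifier_score(images, classifier_fn, num_batches=1):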
    """Classifier score for evaluating a conditional generative model.
    This is based on the Inception Score, but for an arbitrary classifier.
    This technique is described in detail in https://arxiv.org/abs/1606.03498. In
    summary, this function calculates
    exp( E[ KL(p(y|x) || p(y)) ] )
    which captures how different the network's per-image class prediction
    p(y|x) is from the marginal class distribution p(y), estimated here as the
    mean of p(y|x) over all `images`.
    Args:
      images: Images to calculate the classifier score for.
      classifier_fn: A function that takes images and produces logits based on a
        classifier.
      num_batches: Number of batches to split `images` into in order to
        efficiently run them through the classifier network. The number of
        images must be divisible by `num_batches`.
    Returns:
      The classifier score. A floating-point scalar of the same type as the output
      of `classifier_fn`.
    """
    generated_images_list = tf.split(
        images, num_or_size_splits=num_batches)

    # Run the classifier over each split in turn with the memory-efficient `map_fn`.
    logits = tf.map_fn(
        fn=classifier_fn,
        elems=tf.stack(generated_images_list),
        parallel_iterations=1,
        back_prop=False,
        swap_memory=True,
        name='RunClassifier')
    # Merge the per-split logits back into a single [N, num_classes] tensor.
    logits = tf.concat(tf.unstack(logits), 0)
    logits.shape.assert_has_rank(2)

    # Use maximum precision for best results.
    logits_dtype = logits.dtype
    if logits_dtype != tf.float64:
        # tf.to_double is deprecated and removed in TF 2.x; tf.cast works in both.
        logits = tf.cast(logits, tf.float64)

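    p = tf.nn.softmax(logits)          # p(y|x) per image, shape [N, num_classes]
    q = tf.reduce_mean(p, axis=0)      # estimate of p(y), shape [num_classes]
    kl = _kl_divergence(p, logits, q)  # KL(p(y|x) || p(y)) per image, shape [N]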
    kl.shape.assert_has_rank(1)
    log_score = tf.reduce_mean(kl)
    final_score = tf.exp(log_score)

    if logits_dtype != tf.float64:
        final_score = tf.cast(final_score, logits_dtype)
    return final_score
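To cross-check the TensorFlow code above, here is a dependency-free Python sketch of the same quantity, exp(E_x[KL(p(y|x) || p(y))]), computed directly from a list of logits. It is an illustration only and not part of tefla: the names `softmax`, `log_softmax`, and `classifier_score_reference` are hypothetical. Like `classifier_score`, it estimates p(y) as the mean of p(y|x) over the given images and uses log-softmax of the logits for log p(y|x).

```python
import math
from typing import List, Sequence


def softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable softmax of one row of logits."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def log_softmax(row: Sequence[float]) -> List[float]:
    """log(softmax(row)) without taking the log of a tiny probability."""
    m = max(row)
    log_total = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_total for v in row]


def classifier_score_reference(logits: Sequence[Sequence[float]]) -> float:
    """Hypothetical reference: exp(E_x[KL(p(y|x) || p(y))]) for [N, K] logits."""
    p = [softmax(row) for row in logits]  # p(y|x) for each image
    n, k = len(p), len(p[0])
    # Marginal p(y), estimated as the mean prediction over all images.
    q = [sum(p[i][j] for i in range(n)) / n for j in range(k)]
    kls = []
    for row_logits, row_p in zip(logits, p):
        log_p = log_softmax(row_logits)
        # KL(p || q) = sum_j p_j * (log p_j - log q_j); p_j == 0 terms add 0.
        kls.append(sum(pj * (lpj - math.log(qj))
                       for pj, lpj, qj in zip(row_p, log_p, q) if pj > 0))
    return math.exp(sum(kls) / n)
```

Uniform logits give a score of 1; confident predictions spread evenly over K classes approach K; confident predictions that all pick the same class fall back toward 1. The score therefore rewards predictions that are both confident and diverse across classes.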