import tensorflow as tf


def classifier_score(images, classifier_fn, num_batches=1):
"""Classifier score for evaluating a conditional generative model.
This is based on the Inception Score, but for an arbitrary classifier.
This technique is described in detail in https://arxiv.org/abs/1606.03498. In
summary, this function calculates
exp( E[ KL(p(y|x) || p(y)) ] )
which captures how different the network's classification prediction is from
the prior distribution over classes.
Args:
images: Images to calculate the classifier score for.
classifier_fn: A function that takes images and produces logits based on a
classifier.
num_batches: Number of batches to split `generated_images` in to in order to
efficiently run them through the classifier network.
Returns:
The classifier score. A floating-point scalar of the same type as the output
of `classifier_fn`.
"""
  generated_images_list = tf.split(
      images, num_or_size_splits=num_batches)

  # Compute the classifier splits using the memory-efficient `map_fn`.
  logits = tf.map_fn(
      fn=classifier_fn,
      elems=tf.stack(generated_images_list),
      parallel_iterations=1,
      back_prop=False,
      swap_memory=True,
      name='RunClassifier')
  logits = tf.concat(tf.unstack(logits), 0)
  logits.shape.assert_has_rank(2)
  # Use maximum precision for best results.
  logits_dtype = logits.dtype
  if logits_dtype != tf.float64:
    logits = tf.to_double(logits)

  # p is p(y|x) for each image; q is the marginal p(y), averaged over the batch.
  p = tf.nn.softmax(logits)
  q = tf.reduce_mean(p, axis=0)
  kl = _kl_divergence(p, logits, q)
  kl.shape.assert_has_rank(1)
  # The score is the exponential of the average per-image KL divergence.
  log_score = tf.reduce_mean(kl)
  final_score = tf.exp(log_score)

  if logits_dtype != tf.float64:
    final_score = tf.cast(final_score, logits_dtype)
  return final_score
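
# The helper `_kl_divergence` is not defined in this snippet. A minimal sketch,
# assuming it computes the per-example KL(p(y|x) || p(y)) from the logits for
# numerical stability, could look like the following (the exact signature is an
# assumption based on how it is called above):
def _kl_divergence(p, p_logits, q):
  """Per-row KL divergence between softmax(p_logits) and the marginal q.

  Args:
    p: [batch, classes] probabilities, the softmax of `p_logits`.
    p_logits: [batch, classes] logits that produced `p`.
    q: [classes] marginal class distribution, e.g. the mean of `p` over rows.

  Returns:
    A [batch] tensor of per-example KL divergences.
  """
  return tf.reduce_sum(
      p * (tf.nn.log_softmax(p_logits) - tf.math.log(q)), axis=1)

# Example usage (hypothetical names: `my_images` is an [N, H, W, C] tensor and
# `my_classifier` is any images -> logits function, e.g. an Inception network):
#   score = classifier_score(my_images, classifier_fn=my_classifier,
#                            num_batches=4)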