def identify_saliency(grads):
    """Pick the N_PATCHES most salient positions of a gradient map.

    Saliency is the L2 norm of ``grads`` taken over axis 3.  The saliency
    map is passed through ``ops.flatten`` and ``tf.nn.top_k`` selects the
    N_PATCHES largest entries of each row of the result.

    Args:
        grads: gradient of the entropy wrt features.  Must have at least
            four dimensions, because the norm is reduced over axis 3
            (assumed to be the channel axis -- TODO confirm with caller).

    Returns:
        A tuple ``(top_k_values, top_k_idxs, M)``:
            top_k_values: the selected saliency scores.
            top_k_idxs: positions of those scores inside the flattened
                saliency map, aligned element-wise with ``top_k_values``.
            M: the full (un-flattened) saliency map.
    """
    # The small epsilon keeps the argument of sqrt strictly positive.
    M = tf.sqrt(tf.reduce_sum(tf.square(grads), 3) + 1e-8)

    # NOTE(review): ops.flatten is a project helper; it presumably turns each
    # sample's map into a vector so top_k runs per sample -- confirm.
    top_k_values, top_k_idxs = tf.nn.top_k(
        ops.flatten(M), N_PATCHES, sorted=False)

    # Shuffle the patch order for batch normalization.  tf.random_shuffle
    # only permutes the first dimension, so draw a single permutation of the
    # N_PATCHES slots and apply it (via transpose + gather on axis 0) to BOTH
    # outputs.  Shuffling the indices alone would leave top_k_values in the
    # old order and break the value <-> index correspondence.
    perm = tf.random_shuffle(tf.range(N_PATCHES))
    top_k_values = tf.transpose(tf.gather(tf.transpose(top_k_values), perm))
    top_k_idxs = tf.transpose(tf.gather(tf.transpose(top_k_idxs), perm))

    return top_k_values, top_k_idxs, M
评论列表
文章目录