import tensorflow as tf

def kl_divergence(p, q):
    # Expect two batches of (possibly unnormalized) distributions, shape [batch, num_classes].
    tf.assert_rank(p, 2)
    tf.assert_rank(q, 2)
    p_shape = tf.shape(p)
    q_shape = tf.shape(q)
    tf.assert_equal(p_shape, q_shape)
    # Normalize each row so it sums to 1, turning raw scores into probability distributions.
    p_ = tf.divide(p, tf.tile(tf.expand_dims(tf.reduce_sum(p, axis=1), 1), [1, p_shape[1]]))
    q_ = tf.divide(q, tf.tile(tf.expand_dims(tf.reduce_sum(q, axis=1), 1), [1, q_shape[1]]))
    # KL(p || q) = sum_i p_i * log(p_i / q_i), computed per row.
    return tf.reduce_sum(tf.multiply(p_, tf.log(tf.divide(p_, q_))), axis=1)
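
A minimal usage sketch, assuming TensorFlow 1.x graph mode (the code above uses the TF1 tf.log API); the input values here are made up for illustration:

p = tf.constant([[1.0, 1.0, 1.0], [4.0, 1.0, 1.0]])  # hypothetical unnormalized histograms, shape [2, 3]
q = tf.constant([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])

with tf.Session() as sess:
    # First row: both rows normalize to the uniform distribution, so KL = 0.
    # Second row: p normalizes to [2/3, 1/6, 1/6] vs. uniform q, giving KL ≈ 0.231.
    print(sess.run(kl_divergence(p, q)))

Note that this computes KL row-wise and returns one value per batch element; rows containing zeros in q_ (or p_) will produce inf or nan, since the log-ratio is not guarded.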