ops.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:sText2Image 作者: elliottwu 项目源码 文件源码
def kl_divergence(p, q):
    tf.assert_rank(p,2)
    tf.assert_rank(q,2)

    p_shape = tf.shape(p)
    q_shape = tf.shape(q)
    tf.assert_equal(p_shape, q_shape)

    # normalize sum to 1
    p_ = tf.divide(p, tf.tile(tf.expand_dims(tf.reduce_sum(p,axis=1), 1), [1,p_shape[1]]))
    q_ = tf.divide(q, tf.tile(tf.expand_dims(tf.reduce_sum(q,axis=1), 1), [1,p_shape[1]]))

    return tf.reduce_sum(tf.multiply(p_, tf.log(tf.divide(p_, q_))), axis=1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号