bp_mll.py source code


Project: bp-mll-tensorflow  Author: vanHavel
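Read off the code, `bp_mll_loss` returns the BP-MLL pairwise error summed over a batch. The notation here is introduced for illustration: m samples, Y_i the relevant labels of sample i (where `y_true` is 1), Ȳ_i its irrelevant labels, and c^i the row `y_pred[i]`. Each (relevant, irrelevant) pair contributes the exponential of the negative score margin, and each sample is normalized by its number of such pairs:

$$
E = \sum_{i=1}^{m} \frac{1}{|Y_i|\,|\bar{Y}_i|} \sum_{(k,l) \in Y_i \times \bar{Y}_i} \exp\!\bigl(-(c^i_k - c^i_l)\bigr)
$$

In the code, `truth_matrix` selects the pairs (k, l) in Y_i × Ȳ_i, `exp_matrix` holds exp(-(c_k - c_l)), `normalizers` holds |Y_i| |Ȳ_i|, and the final `tf.reduce_sum` sums over i. The error shrinks as relevant labels are scored further above irrelevant ones.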
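```python
import tensorflow as tf

# BP-MLL loss. y_true and y_pred have shape [batch, n_labels];
# y_true holds 1 for relevant labels and 0 otherwise.
# pairwise_and / pairwise_sub are helpers not included in this excerpt.
def bp_mll_loss(y_true, y_pred):

    # get true (relevant) and false (irrelevant) labels
    shape = tf.shape(y_true)
    y_i = tf.equal(y_true, tf.ones(shape))
    y_i_bar = tf.not_equal(y_true, tf.ones(shape))

    # get indices to check: [batch, n, n], 1 where label k is relevant and l is not
    truth_matrix = tf.cast(pairwise_and(y_i, y_i_bar), tf.float32)

    # calculate all exp'd differences exp(-(c_k - c_l))
    sub_matrix = pairwise_sub(y_pred, y_pred)
    exp_matrix = tf.exp(tf.negative(sub_matrix))

    # check which differences to consider and sum them per sample
    sparse_matrix = tf.multiply(exp_matrix, truth_matrix)
    sums = tf.reduce_sum(sparse_matrix, axis=[1, 2])

    # get normalizing terms |Y_i| * |Ybar_i| and apply them;
    # divide_no_nan yields 0 instead of NaN when a sample's labels are
    # all relevant or all irrelevant (normalizer 0)
    y_i_sizes = tf.reduce_sum(tf.cast(y_i, tf.float32), axis=1)
    y_i_bar_sizes = tf.reduce_sum(tf.cast(y_i_bar, tf.float32), axis=1)
    normalizers = tf.multiply(y_i_sizes, y_i_bar_sizes)
    results = tf.math.divide_no_nan(sums, normalizers)

    # sum over samples
    return tf.reduce_sum(results)

# compute pairwise differences between elements of the tensors a and b
```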
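To check the TensorFlow version against hand-computed values, here is a minimal pure-Python sketch of the same computation, one sample at a time. The helpers `pairwise_sub` and `pairwise_and` below are hypothetical list-based stand-ins for the file's TensorFlow helpers, whose bodies are not part of this excerpt; entry [k][l] pairs element k of `a` with element l of `b`, which is what the [batch, n, n] shapes summed over axes 1 and 2 in `bp_mll_loss` imply. `bp_mll_loss_reference` is likewise a hypothetical name. One stated assumption: a sample with no relevant or no irrelevant labels contributes 0 rather than 0/0.

```python
import math
from typing import List, Sequence


# Hypothetical list-based stand-ins for the TensorFlow helpers:
# entry [k][l] combines element k of `a` with element l of `b`.
def pairwise_sub(a: Sequence[float], b: Sequence[float]) -> List[List[float]]:
    return [[ak - bl for bl in b] for ak in a]


def pairwise_and(a: Sequence[bool], b: Sequence[bool]) -> List[List[bool]]:
    return [[ak and bl for bl in b] for ak in a]


def bp_mll_loss_reference(y_true: Sequence[Sequence[float]],
                          y_pred: Sequence[Sequence[float]]) -> float:
    """Sum over samples of the normalized BP-MLL pairwise error."""
    total = 0.0
    for labels, scores in zip(y_true, y_pred):
        y_i = [t == 1 for t in labels]          # relevant labels
        y_i_bar = [t != 1 for t in labels]      # irrelevant labels
        truth = pairwise_and(y_i, y_i_bar)      # (k, l) in Y_i x Ybar_i
        diffs = pairwise_sub(scores, scores)    # c_k - c_l
        n = len(scores)
        pair_sum = sum(math.exp(-diffs[k][l])
                       for k in range(n) for l in range(n) if truth[k][l])
        norm = sum(y_i) * sum(y_i_bar)          # |Y_i| * |Ybar_i|
        # Assumption: a sample with norm == 0 contributes 0 instead of 0/0.
        total += pair_sum / norm if norm else 0.0
    return total
```

For one sample with label 0 relevant and scores [2.0, 0.0, 1.0], the two pairs give (e^-2 + e^-1) / 2:

```python
print(round(bp_mll_loss_reference([[1, 0, 0]], [[2.0, 0.0, 1.0]]), 4))  # → 0.2516
```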