helper.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:GMAN 作者: iDurugkar 项目源码 文件源码
def mix_prediction(losses, lam=0., mean_typ='arithmetic', weight_typ='normal', sign=-1., sf=1e-3):
    """Combine a tensor of per-discriminator losses into one value via a weighted mean.

    Args:
        losses: Tensor of losses, shape (# of discriminators x batch_size).
        lam: Non-negative weighting strength. ``0.`` means uniform weights
            (``tf.ones_like(losses)``). A Python/NumPy real number is
            validated immediately. A tensor is validated when the graph runs.
        mean_typ: One of 'arithmetic', 'geometric' or 'harmonic'.
        weight_typ: How weights are built when ``lam != 0``:
            'normal' -> ``exp(lam * losses)``; 'log' -> ``losses ** lam``.
        sign: ``1.`` or ``-1.``. Only used by the geometric mean, which takes
            ``log(sign * losses)``, so ``sign * losses`` must be positive there.
            With the default ``-1.`` that means negative losses.
        sf: Positive offset for the harmonic mean. Losses are shifted by
            ``min(losses) - sf`` so every shifted value is at least ``sf``
            before taking reciprocals.

    Returns:
        The combined loss; the original comment states the output is a scalar.
        NOTE(review): assumes ``weighted_arithmetic(weights, values)`` (defined
        elsewhere) returns the weight-normalised mean of ``values``; confirm.

    Raises:
        ValueError: if ``lam`` (when a real number), ``mean_typ``,
            ``weight_typ``, ``sign`` or ``sf`` is outside its allowed range/set.
    """
    import numbers

    if isinstance(lam, numbers.Real):
        # `not lam >= 0.` also rejects NaN, matching what the TF assert would do.
        if not lam >= 0.:
            raise ValueError("lam must be non-negative, got %r" % (lam,))
    else:
        # tf.assert_* only returns an op. In graph mode it executes only if
        # something depends on it, so route `lam` through a control dependency.
        with tf.control_dependencies([tf.assert_non_negative(lam)]):
            lam = tf.identity(lam)
    # Explicit raises instead of `assert`, which is stripped under `python -O`.
    if mean_typ not in ('arithmetic', 'geometric', 'harmonic'):
        raise ValueError("mean_typ must be 'arithmetic', 'geometric' or "
                         "'harmonic', got %r" % (mean_typ,))
    if weight_typ not in ('normal', 'log'):
        raise ValueError("weight_typ must be 'normal' or 'log', got %r" % (weight_typ,))
    if sign not in (1., -1.):
        raise ValueError("sign must be 1. or -1., got %r" % (sign,))
    if not sf > 0.:
        raise ValueError("sf must be positive, got %r" % (sf,))

    if isinstance(lam, numbers.Real) and lam == 0.:
        weights = tf.ones_like(losses)
    elif weight_typ == 'log':
        weights = tf.pow(losses, lam)
    else:
        weights = tf.exp(lam * losses)

    if mean_typ == 'arithmetic':
        loss = weighted_arithmetic(weights, losses)
    elif mean_typ == 'geometric':
        # Weighted geometric mean = exp(weighted mean of logs). `sign` flips
        # negative losses positive so the log is defined, then flips back.
        log_losses = tf.log(sign*losses)
        loss = sign*tf.exp(weighted_arithmetic(weights, log_losses))
    else:
        # Weighted harmonic mean of the shifted losses, shifted back by `mn`.
        # reduce_min runs over the whole tensor, so every shifted value is >= sf.
        mn = tf.reduce_min(losses) - sf
        inv_losses = tf.reciprocal(losses-mn)
        loss = mn + tf.reciprocal(weighted_arithmetic(weights, inv_losses))

    return loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号