def mix_prediction(losses, lam=0., mean_typ='arithmetic', weight_typ='normal', sign=-1., sf=1e-3):
    """Combine a stack of losses into one value via a weighted arithmetic,
    geometric, or harmonic mean.

    Args:
        losses: loss tensor; the original author notes its shape as
            (num_discriminators, batch_size) — TODO confirm against callers.
        lam: weighting strength. When exactly 0. every entry gets weight 1;
            otherwise weights are ``exp(lam * losses)`` for ``weight_typ ==
            'normal'`` or ``losses ** lam`` for ``weight_typ == 'log'``.
        mean_typ: 'arithmetic', 'geometric', or 'harmonic'.
        weight_typ: 'normal' or 'log' (see ``lam``).
        sign: must be 1. or -1. The geometric branch takes the log of
            ``sign * losses`` and multiplies the result by ``sign`` again,
            so ``sign * losses`` is presumably expected to be positive there.
        sf: positive offset for the harmonic branch; losses are shifted so
            the smallest one becomes exactly ``sf``.

    Returns:
        The combined loss built from ``weighted_arithmetic`` (described by
        the original author as a scalar).

    Raises:
        AssertionError: if ``mean_typ``, ``weight_typ``, ``sign``, or ``sf``
            is outside its allowed set (plain ``assert`` — skipped under -O).
    """
    # NOTE(review): the op returned here is discarded, so in graph mode it is
    # not a control dependency of the result; whether the check ever fires
    # depends on TF evaluating ``lam`` statically — confirm.
    tf.assert_non_negative(lam)
    assert mean_typ in ('arithmetic', 'geometric', 'harmonic')
    assert weight_typ in ('normal', 'log')
    assert sign in (1., -1.)
    assert sf > 0.

    # Per-entry weights, same shape as ``losses``.
    if lam == 0.:
        w = tf.ones_like(losses)
    elif weight_typ == 'log':
        w = tf.pow(losses, lam)
    else:
        w = tf.exp(lam * losses)

    if mean_typ == 'arithmetic':
        return weighted_arithmetic(w, losses)

    if mean_typ == 'geometric':
        # Average in log space and map back; ``sign`` moves the losses into
        # the domain of log and restores the original sign afterwards.
        log_mean = weighted_arithmetic(w, tf.log(sign * losses))
        return sign * tf.exp(log_mean)

    # Harmonic: shift so every value is >= sf (> 0), which keeps the
    # reciprocals finite. Average the reciprocals, invert, then undo the shift.
    shift = tf.reduce_min(losses) - sf
    inv_mean = weighted_arithmetic(w, tf.reciprocal(losses - shift))
    return shift + tf.reciprocal(inv_mean)
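# 评论列表 ("comment list") — leftover web-page navigation text, not code
# 文章目录 ("table of contents") — leftover web-page navigation text, not code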