import numpy as np
import tensorflow as tf  # uses the TF 1.x API (tf.log, tf.divide)

def optimized_loss(self, targets, logits):
    """Numerically stable loss for a mixture density network.

    Every term is kept in log-space and the mixture components are
    combined with logsumexp, so underflow or overflow in the Gaussian
    densities cannot produce unstable (NaN/Inf) behavior. Returns the
    per-sample log-likelihood; negate it to obtain a minimizable loss.
    """
    # Obtain the mixture parameters from the raw network outputs
    mixings, sigma, mean = self.logits_to_params(logits)
    output_size = tf.cast(tf.shape(targets)[1], tf.float32)
    variance = tf.square(sigma)
    # Move each factor of the mixture density into log-space
    mixings_exp = tf.log(mixings)
    # By the properties of the logarithm the Gaussian normalizer
    # 1 / ((2*pi)^(D/2) * sigma^D) simplifies:
    # log(x/y) = log(x) - log(y), log(xy) = log(x) + log(y), log(1) = 0
    sqrt_exp = -output_size * (0.5 * tf.log(2 * np.pi) + tf.log(sigma))
    # Exponent of the Gaussian density: -(t - mu)^2 / (2 * sigma^2)
    gaussian_exp = -tf.divide(tf.square(targets - mean), 2 * variance)
    exponent = mixings_exp + sqrt_exp + gaussian_exp
    # reduce_logsumexp subtracts the max exponent internally, which
    # controls underflow/overflow when summing over mixture components
    return tf.reduce_logsumexp(exponent, axis=1)
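
For context, a minimal sketch of how this method might be wired into training. The logits_to_params helper shown here (softmax mixing weights, softplus scales) is a hypothetical parameterization assumed for illustration, not necessarily the one this class implements, and model, targets, and logits stand in for the caller's own tensors:

# Hypothetical logits_to_params, assumed for illustration only: the
# network emits K mixing logits, K scale logits and K means, and the
# real implementation may split or activate them differently.
def logits_to_params(self, logits):
    mixing_logits, sigma_logits, mean = tf.split(logits, 3, axis=1)
    mixings = tf.nn.softmax(mixing_logits)       # weights sum to 1
    sigma = tf.nn.softplus(sigma_logits) + 1e-6  # strictly positive scales
    return mixings, sigma, mean

# Training objective: the method returns the per-sample log-likelihood,
# so the loss to minimize is its negated batch average.
log_likelihood = model.optimized_loss(targets, logits)
loss = -tf.reduce_mean(log_likelihood)
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)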
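
The stability of the final step comes from the standard logsumexp identity, log(sum_k exp(x_k)) = m + log(sum_k exp(x_k - m)) with m = max_k x_k, which tf.reduce_logsumexp applies internally. A small NumPy sketch shows why the shift matters when the exponents sit far below the range of exp:

import numpy as np

x = np.array([-1000.0, -1001.0, -1002.0])   # log-space terms, e.g. log densities
naive = np.log(np.sum(np.exp(x)))           # exp underflows to 0 -> -inf
m = np.max(x)
stable = m + np.log(np.sum(np.exp(x - m)))  # shift by the max before exponentiating
print(naive, stable)                        # -inf vs. approximately -999.59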