distributions.py source file

python

Project: wide-deep-cnn  Author: DaniUPC
import numpy as np
import tensorflow as tf


def optimized_loss(self, targets, logits):
    """Compute the loss term of a mixture density network in the log
    domain, so that underflow and overflow do not make it unstable."""
    # Obtain the mixture parameters from the network output
    mixings, sigma, mean = self.logits_to_params(logits)
    output_size = tf.cast(tf.shape(targets)[1], tf.float32)
    variance = tf.square(sigma)
    # Take the logarithm of each factor of the mixture density
    # (tf.log is the TensorFlow 1.x name; TensorFlow 2.x uses tf.math.log)
    mixings_exp = tf.log(mixings)
    # By the properties of the logarithm the original expression simplifies:
    # log(x/y) = log(x) - log(y), log(xy) = log(x) + log(y), log(1) = 0
    sqrt_exp = -output_size * (0.5 * tf.log(2 * np.pi) + tf.log(sigma))
    gaussian_exp = -tf.divide(tf.square(targets - mean), 2 * variance)
    exponent = mixings_exp + sqrt_exp + gaussian_exp
    # reduce_logsumexp controls underflow/overflow when summing the components.
    # Note: the value returned is the per-example log-likelihood; it is not
    # negated here.
    return tf.reduce_logsumexp(exponent, axis=1)
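
To illustrate why the log-domain formulation matters, here is a minimal NumPy sketch of the same computation next to a naive version that multiplies densities directly. It is not part of the wide-deep-cnn project: the function names (`logsumexp`, `mdn_log_likelihood`, `naive_log_likelihood`), the sample values, and the shapes (targets of shape `(batch, 1)`, parameters of shape `(batch, K)`) are all assumptions made for this illustration.

```python
import numpy as np


def logsumexp(a, axis=1):
    """Stable log(sum(exp(a))): subtract the row maximum before exponentiating."""
    a_max = np.max(a, axis=axis, keepdims=True)
    summed = np.sum(np.exp(a - a_max), axis=axis)
    return np.squeeze(a_max, axis=axis) + np.log(summed)


def mdn_log_likelihood(targets, mixings, sigma, mean):
    """Log-domain mixture log-likelihood, mirroring the terms used above.

    Assumed shapes: targets (batch, 1); mixings, sigma, mean (batch, K).
    """
    output_size = targets.shape[1]
    exponent = (
        np.log(mixings)
        - output_size * (0.5 * np.log(2 * np.pi) + np.log(sigma))
        - np.square(targets - mean) / (2 * np.square(sigma))
    )
    return logsumexp(exponent, axis=1)


def naive_log_likelihood(targets, mixings, sigma, mean):
    """Direct evaluation for one-dimensional targets; underflows far from the means."""
    density = (
        mixings
        * np.exp(-np.square(targets - mean) / (2 * np.square(sigma)))
        / (np.sqrt(2 * np.pi) * sigma)
    )
    with np.errstate(divide="ignore"):  # silence the log(0) warning
        return np.log(np.sum(density, axis=1))


# Hypothetical two-component mixture; the second target lies far from both means.
targets = np.array([[0.0], [100.0]])
mixings = np.array([[0.3, 0.7], [0.3, 0.7]])
sigma = np.array([[1.0, 2.0], [1.0, 2.0]])
mean = np.array([[0.0, 1.0], [0.0, 1.0]])

stable = mdn_log_likelihood(targets, mixings, sigma, mean)
naive = naive_log_likelihood(targets, mixings, sigma, mean)
print(stable)
print(naive)
```

For the first target both versions agree. For the second, every component density underflows to zero in the naive version, so its logarithm is negative infinity, while the log-domain version stays finite.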