cost.py source code

python

Project: deepsleepnet  Author: akaraspt
```python
import tensorflow as tf


def mean_squared_error(output, target, is_mean=False):
    """Return the TensorFlow expression of the mean squared error between two tensors.

    Parameters
    ----------
    output : 2D or 4D tensor.
    target : 2D or 4D tensor.
    is_mean : boolean
        If True, use ``tf.reduce_mean`` to compute the loss of each sample;
        otherwise (default), use ``tf.reduce_sum``. In both cases the
        per-sample losses are then averaged over the batch.

    References
    ----------
    - `Wiki Mean Squared Error <https://en.wikipedia.org/wiki/Mean_squared_error>`_
    """
    with tf.name_scope("mean_squared_error_loss"):
        # Element-wise squared error, same shape as `output`.
        sq_diff = tf.squared_difference(output, target)
        ndims = output.get_shape().ndims
        if ndims == 2:    # [batch_size, n_feature]
            axes = 1
        elif ndims == 4:  # [batch_size, w, h, c]
            axes = [1, 2, 3]
        else:
            raise ValueError(
                "mean_squared_error expects a 2D or 4D output tensor, got ndims=%s" % ndims)
        # Reduce over every non-batch axis, then average over the batch.
        if is_mean:
            mse = tf.reduce_mean(tf.reduce_mean(sq_diff, axes))
        else:
            mse = tf.reduce_mean(tf.reduce_sum(sq_diff, axes))
        return mse
```
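The only difference between the two `is_mean` settings is the reduction applied to each sample before the batch average. The dependency-free sketch below mirrors that reduction order with plain Python lists so the two modes can be checked by hand. The names `mse_reference` and `_flatten` are hypothetical helpers for illustration, not part of deepsleepnet or TensorFlow.

```python
def _flatten(x):
    """Flatten an arbitrarily nested list of numbers into a flat list of floats."""
    if isinstance(x, (int, float)):
        return [float(x)]
    out = []
    for item in x:
        out.extend(_flatten(item))
    return out


def mse_reference(output, target, is_mean=False):
    """Hypothetical pure-Python sketch of the reduction order in mean_squared_error.

    For each sample (axis 0), square the element-wise differences over all
    remaining axes and either sum them (is_mean=False) or average them
    (is_mean=True); then average the per-sample values over the batch.
    """
    if len(output) != len(target):
        raise ValueError("output and target must have the same batch size")
    per_sample = []
    for out_i, tgt_i in zip(output, target):
        o, t = _flatten(out_i), _flatten(tgt_i)
        if len(o) != len(t):
            raise ValueError("per-sample shapes of output and target differ")
        sq = [(a - b) ** 2 for a, b in zip(o, t)]
        per_sample.append(sum(sq) / len(sq) if is_mean else sum(sq))
    return sum(per_sample) / len(per_sample)


# 2D example: batch of 2 samples, 3 features each.
output = [[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]
target = [[1.0, 2.0, 5.0], [1.0, 1.0, 1.0]]
print(mse_reference(output, target))                # (4 + 3) / 2 = 3.5
print(mse_reference(output, target, is_mean=True))  # (4/3 + 1) / 2, roughly 1.1667
```

With `is_mean=False` the loss grows with the number of features per sample; with `is_mean=True` each sample contributes the average of its squared errors, which here equals the mean over all six elements (7/6). Nested lists of shape `[batch_size, w, h, c]` go through the same code path, matching the 4D branch.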