def mean_squared_error(output, target, is_mean=False):
    """Return the TensorFlow expression of the mean squared error of two distributions.

    The element-wise squared difference is first reduced over every non-batch
    axis of each sample (summed by default, averaged when ``is_mean`` is True).
    The resulting per-sample losses are then averaged over the batch (axis 0).

    Parameters
    ----------
    output : tensor of static rank >= 2
        For example 2D ``[batch_size, n_feature]`` or 4D ``[batch_size, w, h, c]``.
    target : tensor
        Compared element-wise against ``output`` via ``tf.squared_difference``
        (presumably the same shape as ``output`` -- not checked here).
    is_mean : boolean
        If True, use ``tf.reduce_mean`` to compute the loss of one sample;
        otherwise use ``tf.reduce_sum`` (default).

    Returns
    -------
    A scalar tensor: the per-sample losses averaged over the batch.

    Raises
    ------
    ValueError
        If the static rank of ``output`` is unknown or smaller than 2.

    References
    ----------
    - `Wiki Mean Squared Error <https://en.wikipedia.org/wiki/Mean_squared_error>`_
    """
    ndims = output.get_shape().ndims
    # Previously any rank other than 2 or 4 matched no branch, leaving the
    # result unbound and failing with UnboundLocalError; fail early instead.
    if ndims is None or ndims < 2:
        raise ValueError(
            "mean_squared_error expects `output` with a known rank >= 2 "
            "(batch axis plus feature axes), got rank %s" % (ndims,))
    with tf.name_scope("mean_squared_error_loss"):
        # Axis 0 is the batch; all remaining axes belong to a single sample.
        # This yields [1] for rank 2 and [1, 2, 3] for rank 4, matching the
        # axes that used to be hard-coded for those ranks.
        sample_axes = list(range(1, ndims))
        squared = tf.squared_difference(output, target)
        reduce_sample = tf.reduce_mean if is_mean else tf.reduce_sum
        mse = tf.reduce_mean(reduce_sample(squared, sample_axes))
    return mse
评论列表
文章目录