metrics.py source file

python

Project: tflearn  Author: tflearn
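````python
import tensorflow as tf


def weighted_r2_op(predictions, targets, inputs):
    """ weighted_r2_op.

    An op that calculates the weighted R2 metric: the sum of squared
    differences between predictions and inputs, divided by the sum of
    squared differences between targets and inputs.

    Examples:
        ```python
        # Inputs must have the same shape as the targets
        input_data = placeholder(shape=[None, 10])
        y_pred = my_network(input_data) # Apply some ops
        y_true = placeholder(shape=[None, 10]) # Labels
        r2_op = weighted_r2_op(y_pred, y_true, input_data)

        # Calculate the weighted R2 by feeding data X and labels Y
        r2 = sess.run(r2_op, feed_dict={input_data: X, y_true: Y})
        ```

    Arguments:
        predictions: `Tensor`.
        targets: `Tensor`.
        inputs: `Tensor`, or a list of `Tensor` that is summed element-wise.

    Returns:
        `Float`. The weighted R2 value.

    """
    with tf.name_scope('WeightedStandardError'):
        # A list of input tensors is reduced to a single tensor by summing
        if hasattr(inputs, '__len__'):
            inputs = tf.add_n(inputs)
        # The metric requires inputs and targets to share the same static shape
        if inputs.get_shape().as_list() != targets.get_shape().as_list():
            raise Exception("Weighted R2 metric requires Inputs and Targets to "
                            "have same shape.")
        # a: squared distance between predictions and inputs
        a = tf.reduce_sum(tf.square(predictions - inputs))
        # b: squared distance between targets and inputs
        b = tf.reduce_sum(tf.square(targets - inputs))
        return tf.divide(a, b)
````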
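To make the arithmetic concrete, here is a minimal dependency-free sketch of the same ratio. The helper name `weighted_r2_reference` and the sample values are hypothetical and not part of tflearn. The sketch also assumes flat Python sequences rather than tensors.

```python
def weighted_r2_reference(predictions, targets, inputs):
    """Plain-Python mirror of the ratio a / b computed in the op above.

    Assumption for this sketch: all arguments are flat sequences of
    floats with equal length (the op itself works on tensors).
    """
    if not (len(predictions) == len(targets) == len(inputs)):
        raise ValueError("predictions, targets and inputs must have equal length")
    # a: sum of squared differences between predictions and inputs
    a = sum((p - x) ** 2 for p, x in zip(predictions, inputs))
    # b: sum of squared differences between targets and inputs
    b = sum((t - x) ** 2 for t, x in zip(targets, inputs))
    return a / b


# Hypothetical sample values, chosen only to illustrate the computation
inputs = [1.0, 2.0, 3.0]
targets = [2.0, 4.0, 6.0]
predictions = [2.0, 3.0, 5.0]

# a = 1 + 1 + 4 = 6, b = 1 + 4 + 9 = 14
ratio = weighted_r2_reference(predictions, targets, inputs)
print(ratio)
```

When the predictions equal the targets, the ratio is exactly 1. If the targets equal the inputs at every position, `b` is zero and this sketch raises `ZeroDivisionError`.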
