test_common.py 文件源码

python
阅读 47 收藏 0 点赞 0 评论 0

项目:keras-rcnn 作者: broadinstitute 项目源码 文件源码
def test_smooth_l1():
    output = keras.backend.variable(
        [[[2.5, 0.0, 0.4, 0.0],
          [0.0, 0.0, 0.0, 0.0],
          [0.0, 2.5, 0.0, 0.4]],
         [[3.5, 0.0, 0.0, 0.0],
          [0.0, 0.4, 0.0, 0.9],
          [0.0, 0.0, 1.5, 0.0]]]
    )

    target = keras.backend.zeros_like(output)

    x = keras_rcnn.backend.smooth_l1(output, target)

    numpy.testing.assert_approx_equal(keras.backend.eval(x), 8.645)

    weights = keras.backend.variable(
        [[2, 1, 1],
         [0, 3, 0]]
    )

    x = keras_rcnn.backend.smooth_l1(output, target, weights=weights)

    numpy.testing.assert_approx_equal(keras.backend.eval(x), 7.695)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号