test_metrics.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:nuts-ml 作者: maet3608 项目源码 文件源码
def test_box_pr_curve():
    approx = lambda prc: [(round(p, 2), round(r, 2), s) for p, r, s in prc]

    boxes1 = [(1, 1, 3, 3), (4, 2, 2, 3), (5, 5, 2, 1)]
    boxes2 = [(2, 1, 2, 3), (4, 3, 2, 3)]
    scores1 = [0.5, 0.2, 0.1]
    scores2 = [0.5, 0.2]

    pr_curve = list(nm.box_pr_curve(boxes2, boxes2, scores2))
    expected = [(1.0, 0.5, 0.5), (1.0, 1.0, 0.2)]
    assert pr_curve == expected

    pr_curve = list(nm.box_pr_curve(boxes1, boxes2, scores2))
    expected = [(1.0, 0.33, 0.5), (1.0, 0.67, 0.2)]
    assert approx(pr_curve) == expected

    pr_curve = list(nm.box_pr_curve(boxes2, boxes1, scores1))
    expected = [(1.0, 0.5, 0.5), (1.0, 1.0, 0.2), (0.67, 1.0, 0.1)]
    assert approx(pr_curve) == expected

    pr_curve = list(nm.box_pr_curve(boxes1, [], []))
    assert pr_curve == []

    pr_curve = list(nm.box_pr_curve([], boxes1, scores1))
    assert pr_curve == []
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号