test_objectives.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:keraflow 作者: ipod825 项目源码 文件源码
def test_accuracy():
    def cat_acc(y_pred, y_true):
        return np.expand_dims(np.equal(np.argmax(y_pred, axis=-1), np.argmax(y_true, axis=-1)), -1),

    objectives_test(objectives.accuracy,
                    cat_acc,
                    np_pred=[[0,0,.9], [0,.9,0], [.9,0,0]],
                    np_true=[[0,0,1], [0,0,1], [0,0,1]])

    def bi_acc(y_pred, y_true):
        return np.equal(np.round(y_pred), y_true)

    objectives_test(objectives.accuracy,
                    bi_acc,
                    np_pred=[[0], [0.6], [0.7]],
                    np_true=[[0], [1], [1]])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号