test_nnet.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_confusion_matrix():
    """Check Theano's ``confusion_matrix`` against a NumPy reference.

    A Theano function is compiled from two symbolic vectors. For each
    hand-written ``(actual, predicted)`` case, its outputs are compared
    element-wise with those of ``numpy_conf_mat`` below.
    """
    def numpy_conf_mat(actual, pred):
        """Return ``[conf_mat, order]`` for the label sequences given.

        ``order`` is ``numpy.union1d(actual, pred)``, i.e. the sorted unique
        labels seen in either input. ``conf_mat[i, j]`` counts the positions
        where ``actual == order[i]`` and ``pred == order[j]``. Rows are
        therefore actual labels and columns are predicted labels.
        """
        order = numpy.union1d(actual, pred)
        # Reshape each input to an (n, 1) column so that comparing it with
        # the (k,) label row broadcasts to an (n, k) one-hot array. Plain
        # 2-D ndarrays replace the deprecated ``numpy.matrix`` used before.
        col_actual = numpy.asarray(actual).reshape(-1, 1)
        col_pred = numpy.asarray(pred).reshape(-1, 1)
        one_hot_actual = (col_actual == order).astype('int64')
        one_hot_pred = (col_pred == order).astype('int64')
        # (k, n) x (n, k) -> (k, k) co-occurrence counts of label pairs.
        conf_mat = numpy.dot(one_hot_actual.T, one_hot_pred)
        return [conf_mat, order]

    x = tensor.vector()
    y = tensor.vector()
    f = theano.function([x, y], confusion_matrix(x, y))
    list_inputs = [[[0, 1, 2, 1, 0], [0, 0, 2, 1, 2]],
                   [[2, 0, 2, 2, 0, 1], [0, 0, 2, 2, 0, 2]]]

    for actual, pred in list_inputs:
        out_exp = numpy_conf_mat(actual, pred)
        # NOTE(review): plain Python lists are passed here, not int64
        # ndarrays. The symbolic vectors presumably use floatX, and Theano
        # may reject an implicit int64 -> float32 array conversion while
        # accepting a lossless list conversion. Confirm before changing.
        outs = f(actual, pred)
        # zip() would silently ignore missing or extra outputs. Require the
        # Theano result to have as many entries as the reference
        # (matrix, labels) pair.
        assert len(outs) == len(out_exp)
        for exp, out in zip(out_exp, outs):
            utt.assert_allclose(exp, out)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号