mnist.py source code

python

Project: convnet-nolearn    Author: jcouvy
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix


def display_confusion_matrix(test_data, test_labels, save=False):
    """
    Plot a matrix showing the choices made by the network
    on a test batch.
    The x axis shows the predicted labels;
    the y axis shows the expected (true) labels.

    If save is True, the figure is also saved
    as a .png image.
    """
    expected = test_labels
    # `mnist` is assumed to be the trained network object defined
    # elsewhere in mnist.py.
    predicted = mnist.predict(test_data)
    cm = confusion_matrix(expected, predicted)
    plt.matshow(cm)
    plt.title('Confusion matrix')
    plt.colorbar()
    plt.ylabel('Expected label')
    plt.xlabel('Predicted label')
    # Save before show(): once the plot window is closed the figure can
    # be released, and a later savefig() would write a blank image.
    if save:
        plt.savefig("../results/mnist/confusion_matrix.png")
    plt.show()
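To make the axis convention in the docstring concrete, here is a minimal pure-Python sketch of what a confusion matrix counts. It follows the orientation documented for scikit-learn's confusion_matrix (entry [i][j] is the number of samples whose true label is i and whose predicted label is j), which is why expected labels end up on the y axis and predicted labels on the x axis. It assumes integer class labels in range(n_classes) (10 for MNIST digits); the helper names confusion_counts and accuracy_from_cm are hypothetical and belong neither to the project nor to scikit-learn.

```python
from typing import List, Sequence


def confusion_counts(expected: Sequence[int], predicted: Sequence[int],
                     n_classes: int = 10) -> List[List[int]]:
    """Count (expected, predicted) label pairs.

    Row i is the expected (true) label i, column j is the predicted
    label j, matching scikit-learn's confusion_matrix orientation.
    Assumes labels are integers in range(n_classes).
    """
    if len(expected) != len(predicted):
        raise ValueError("expected and predicted must have the same length")
    cm = [[0] * n_classes for _ in range(n_classes)]
    for true_label, pred_label in zip(expected, predicted):
        cm[true_label][pred_label] += 1
    return cm


def accuracy_from_cm(cm: List[List[int]]) -> float:
    """Correct predictions lie on the diagonal; off-diagonal cells are errors."""
    total = sum(sum(row) for row in cm)
    correct = sum(cm[i][i] for i in range(len(cm)))
    return correct / total if total else 0.0


if __name__ == "__main__":
    # Hypothetical toy labels, not results from the convnet-nolearn project.
    expected = [0, 1, 2, 2, 1, 0]
    predicted = [0, 2, 2, 2, 1, 0]
    cm = confusion_counts(expected, predicted, n_classes=3)
    for row in cm:
        print(row)
    print(accuracy_from_cm(cm))
```

In the toy run, the single mistake (a true 1 predicted as 2) shows up as an off-diagonal count in row 1, column 2. A bright off-diagonal cell in the plotted matrix reads the same way: a digit the network regularly confuses with another.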