cifar10_wrn.py 文件源码

python
阅读 42 收藏 0 点赞 0 评论 0

项目: lemontree 作者: khshim 项目源码 文件源码
def test_testset(n_votes=10, n_classes=10):
    """Evaluate the model on the test set using test-time-augmentation voting.

    Each test minibatch is fetched ``n_votes`` times. Per the original
    author's note, every fetch returns the same samples with different
    preprocessing. The predicted class indices of each pass are one-hot
    encoded with ``int_to_onehot`` and summed into a per-sample vote tally.
    The class with the most votes is the final prediction; ties go to the
    lowest class index, as with ``np.argmax``.

    The sample-weighted accuracy over all minibatches is appended to
    ``hist.history['test_accuracy']``. It is NaN if there are no batches.

    Relies on module-level objects defined elsewhere in the script:
    ``graph``, ``test_gen``, ``test_func_output``, ``int_to_onehot`` and
    ``hist``.

    Args:
        n_votes: number of preprocessed passes per minibatch (default 10,
            matching the previous hard-coded value). Must be at least 1.
        n_classes: number of output classes (default 10, as for CIFAR-10).

    Raises:
        ValueError: if ``n_votes`` is less than 1.
    """
    if n_votes < 1:
        raise ValueError('n_votes must be at least 1, got %r' % (n_votes,))

    # NOTE(review): presumably switches the graph into inference mode; the
    # flag is not restored afterwards here — confirm against graph.change_flag.
    graph.change_flag(-1)

    total_correct = 0
    total_seen = 0
    for index in range(test_gen.max_index):
        votes = None
        testset = None
        for _ in range(n_votes):
            # Re-sample: same data, different preprocessing.
            testset = test_gen.get_minibatch(index)
            # Cast to an integer array so votes are counted (the original
            # accumulated into an int32 buffer) even if int_to_onehot
            # returns bools or a plain list.
            onehot = np.asarray(
                int_to_onehot(test_func_output(testset[0]), n_classes),
                dtype=np.int64)
            # Size the tally from the real output rather than a hard-coded
            # batch of 128, so a short final minibatch no longer breaks.
            votes = onehot if votes is None else votes + onehot

        predictions = np.argmax(votes, axis=-1)
        # Labels do not depend on preprocessing, so reuse the last pass
        # instead of fetching/preprocessing the minibatch once more.
        # NOTE(review): assumes testset[1] holds integer class indices, as
        # the element-wise comparison with argmax output implies — confirm.
        matches = np.equal(predictions, testset[1])
        total_correct += int(np.sum(matches))
        total_seen += matches.size

    # Weight by sample count so unequal batch sizes do not bias the result;
    # float() guards against integer division on Python 2.
    if total_seen:
        accuracy = float(total_correct) / total_seen
    else:
        accuracy = float('nan')
    hist.history['test_accuracy'].append(accuracy)

#================Train================#
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号