test_models_keras.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:foolbox 作者: bethgelab 项目源码 文件源码
def test_keras_model_probs(num_classes):
    """Check that KerasModel gives consistent predictions for logit and
    probability outputs.

    Builds three Keras models over the same (5, 5, num_classes) input:

    - a global-average-pooling head whose output is treated as logits
      (``predicts='logits'``);
    - the same head followed by a softmax, declared as
      ``predicts='probabilities'``;
    - the same head followed by a softmax, declared via the ``'probs'``
      spelling.

    All three must return arrays of shape ``(2, num_classes)`` for a batch
    of two images. Their outputs must also agree up to a per-example
    additive constant.

    Args:
        num_classes: number of output classes. It is also used as the
            number of input channels, because global average pooling maps
            each channel to one output. Presumably supplied by a pytest
            fixture or parametrization (not visible here).
    """
    bounds = (0, 255)
    channels = num_classes

    with warnings.catch_warnings():
        # Silence DeprecationWarnings raised while building the models
        # (presumably from Keras internals), for this block only.
        warnings.filterwarnings("ignore", category=DeprecationWarning)
        inputs = Input(shape=(5, 5, channels))
        logits = GlobalAveragePooling2D(
            data_format='channels_last')(inputs)
        probs = Activation(softmax)(logits)

        model1 = KerasModel(
            Model(inputs=inputs, outputs=logits),
            bounds=bounds,
            predicts='logits')

        model2 = KerasModel(
            Model(inputs=inputs, outputs=probs),
            bounds=bounds,
            predicts='probabilities')

        # Same network as model2, declared with the short 'probs' spelling.
        model3 = KerasModel(
            Model(inputs=inputs, outputs=probs),
            bounds=bounds,
            predicts='probs')

    # A local generator yields the same values as np.random.seed(22)
    # followed by np.random.rand, without touching NumPy's global RNG
    # state that other tests may share.
    rng = np.random.RandomState(22)
    test_images = rng.rand(2, 5, 5, channels).astype(np.float32)

    p1 = model1.batch_predictions(test_images)
    p2 = model2.batch_predictions(test_images)
    p3 = model3.batch_predictions(test_images)

    assert p1.shape == p2.shape == p3.shape == (2, num_classes)

    def _shift_rows(p):
        # Subtract each row's own max. Logits recovered from softmax
        # probabilities (log p = x - logsumexp(x)) differ from the raw
        # logits by a *per-example* constant, which a single global max
        # over the whole batch would not cancel.
        return p - p.max(axis=1, keepdims=True)

    # The loose tolerance is kept from the original comparison. It
    # presumably covers backends where converting probabilities back to
    # logits is numerically imprecise (TODO: confirm against KerasModel).
    np.testing.assert_array_almost_equal(
        _shift_rows(p1),
        _shift_rows(p2),
        decimal=1)

    # model2 and model3 are built from the same tensors, so they are
    # expected to agree tightly.
    np.testing.assert_array_almost_equal(
        _shift_rows(p2),
        _shift_rows(p3),
        decimal=5)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号