test_models_lasagne.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:foolbox 作者: bethgelab 项目源码 文件源码
def test_lasagne_model(num_classes):
    """Check that ``LasagneModel`` returns consistently shaped outputs.

    Builds a minimal Lasagne network whose logits are produced by a
    ``GlobalPoolLayer`` applied directly to the input layer. It wraps that
    network in ``LasagneModel`` and checks that:

    * batch predictions have shape ``(batch, num_classes)``;
    * single-image predictions have shape ``(num_classes,)``;
    * the gradient has the same shape as the input image;
    * ``predictions_and_gradient`` agrees with ``predictions`` and
      ``gradient`` computed separately;
    * ``num_classes()`` reports the expected number of classes.

    Parameters
    ----------
    num_classes : int
        Number of output classes. It is also used as the number of input
        channels, because the global pooling layer keeps the channel axis
        and so yields one logit per channel.
    """
    bounds = (0, 255)
    channels = num_classes

    def mean_brightness_net(images):
        # Pool each channel over its spatial dimensions; the pooled
        # per-channel values serve directly as logits. (Presumably a mean,
        # Lasagne's default pool_function -- not overridden here.)
        logits = GlobalPoolLayer(images)
        return logits

    images_var = T.tensor4('images', dtype='float32')
    images = InputLayer((None, channels, 5, 5), images_var)
    logits = mean_brightness_net(images)

    model = LasagneModel(
        images,
        logits,
        bounds=bounds)

    test_images = np.random.rand(2, channels, 5, 5).astype(np.float32)
    # Any valid class index works here. Wrapping keeps the original label 7
    # whenever num_classes > 7, and keeps it in range for smaller values.
    test_label = 7 % num_classes

    assert model.batch_predictions(test_images).shape \
        == (2, num_classes)

    test_logits = model.predictions(test_images[0])
    assert test_logits.shape == (num_classes,)

    test_gradient = model.gradient(test_images[0], test_label)
    assert test_gradient.shape == test_images[0].shape

    # Evaluate the combined call once, then compare both of its outputs
    # with the separately computed predictions and gradient.
    combined = model.predictions_and_gradient(test_images[0], test_label)
    np.testing.assert_almost_equal(combined[0], test_logits)
    np.testing.assert_almost_equal(combined[1], test_gradient)

    assert model.num_classes() == num_classes
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号