test_normalization.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:keras-customized 作者: ambrite 项目源码 文件源码
def test_batchnorm_mode_0_or_2():
    for mode in [0, 2]:
        model = Sequential()
        norm_m0 = normalization.BatchNormalization(mode=mode, input_shape=(10,), momentum=0.8)
        model.add(norm_m0)
        model.compile(loss='mse', optimizer='sgd')

        # centered on 5.0, variance 10.0
        X = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10))
        model.fit(X, X, nb_epoch=4, verbose=0)
        out = model.predict(X)
        out -= K.eval(norm_m0.beta)
        out /= K.eval(norm_m0.gamma)

        assert_allclose(out.mean(), 0.0, atol=1e-1)
        assert_allclose(out.std(), 1.0, atol=1e-1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号