def test_shared_batchnorm():
    """Check that one BatchNormalization layer can serve several inputs.

    The layer is first applied to two separate input tensors, and a model
    built on the second one is trained for a single batch. That model is
    then called on a third input, wrapped in a new model, and trained for
    a single batch as well.
    """
    n_features = 10

    # --- Layer-level reuse: the same BN instance is called on two inputs.
    shared_bn = normalization.BatchNormalization(input_shape=(n_features,),
                                                 mode=0)
    first_input = Input(shape=(n_features,))
    # Result intentionally discarded; only the extra call on the layer matters.
    shared_bn(first_input)
    second_input = Input(shape=(n_features,))
    second_output = shared_bn(second_input)

    # Random batch of 2 samples, used as both input and target below.
    batch = np.random.normal(5.0, 10.0, (2, n_features))

    model = Model(second_input, second_output)
    # The model built on the second input must report exactly two updates.
    assert len(model.updates) == 2
    model.compile(optimizer='sgd', loss='mse')
    model.train_on_batch(batch, batch)

    # --- Model-level reuse: call the whole model on a fresh input.
    third_input = Input(shape=(n_features,))
    third_output = model(third_input)
    new_model = Model(third_input, third_output)
    # NOTE(review): this re-checks `model`, not `new_model` -- confirm that
    # is the intended target of the assertion.
    assert len(model.updates) == 2
    new_model.compile(optimizer='sgd', loss='mse')
    new_model.train_on_batch(batch, batch)
评论列表
文章目录