def test_trainable_argument():
    """Check that a layer created with ``trainable=False`` is not updated by training.

    A single ``Dense`` layer is built with ``trainable=False``. Predictions are
    taken before and after one ``train_on_batch`` step and must match. The same
    check is then repeated with that model wrapped inside a functional ``Model``.
    """
    x = np.random.random((5, 3))
    y = np.random.random((5, 2))

    def assert_training_is_noop(net):
        # One optimizer step on (x, y) must leave the predictions for x
        # unchanged, since the only layer was declared non-trainable.
        net.compile('rmsprop', 'mse')
        before = net.predict(x)
        net.train_on_batch(x, y)
        after = net.predict(x)
        assert_allclose(before, after)

    frozen = Sequential()
    frozen.add(Dense(2, input_dim=3, trainable=False))
    assert_training_is_noop(frozen)

    # Test with nesting: call the frozen model on a fresh input tensor and
    # wrap the result in a new Model; the inner layer should remain frozen.
    inputs = Input(shape=(3,))  # named `inputs` to avoid shadowing builtin `input`
    outputs = frozen(inputs)
    nested = Model(inputs, outputs)
    assert_training_is_noop(nested)
# (web-page residue) Comment list
# (web-page residue) Article table of contents