def _check_trainable_toggling(model, layer_index):
    """Assert that ``trainable`` flags move ``model``'s weights between lists.

    Starting from a freshly built model, this checks the partition of
    ``model.weights`` into ``trainable_weights`` / ``non_trainable_weights``:

    1. all weights start out trainable;
    2. ``model.trainable = False`` makes every weight non-trainable;
    3. ``model.trainable = True`` restores the original partition;
    4. freezing ``model.layers[layer_index]`` makes every weight non-trainable.

    Step 4 only holds if the layer at ``layer_index`` owns all of the model's
    weights, so callers must pass a model with a single weighted layer.

    Args:
        model: A freshly built model whose weights all start trainable.
            The model is mutated (its trainable flags are changed).
        layer_index: Index into ``model.layers`` of the weighted layer that
            is frozen in the final step.
    """
    weights = model.weights
    # Guard against a vacuous pass: with no weights, every comparison below
    # would be between empty lists and succeed trivially.
    assert weights

    # Freshly built model: every weight is expected to be trainable.
    assert model.trainable_weights == weights
    assert model.non_trainable_weights == []

    # Freezing the whole model moves every weight to the non-trainable list.
    model.trainable = False
    assert model.trainable_weights == []
    assert model.non_trainable_weights == weights

    # Unfreezing the model restores the original partition.
    model.trainable = True
    assert model.trainable_weights == weights
    assert model.non_trainable_weights == []

    # Freezing the single weighted layer is enough to freeze all weights.
    model.layers[layer_index].trainable = False
    assert model.trainable_weights == []
    assert model.non_trainable_weights == weights


def test_trainable_weights():
    """Check that model- and layer-level ``trainable`` flags are honoured.

    Runs the same toggle sequence against a functional ``Model`` and a
    ``Sequential`` model, each containing one ``Dense(1)`` layer on a
    2-feature input.
    """
    # Functional model: the Dense layer sits at layers[1]
    # (layers[0] is presumably the InputLayer created by Input()).
    a = Input(shape=(2,))
    b = Dense(1)(a)
    _check_trainable_toggling(Model(a, b), layer_index=1)

    # Sequential model: the Dense layer is layers[0].
    model = Sequential()
    model.add(Dense(1, input_dim=2))
    _check_trainable_toggling(model, layer_index=0)
评论列表
文章目录