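# test_topology.py — file source code
#
# Language: python
# Views: 20 · Favorites: 0 · Likes: 0 · Comments: 0
#
# Project: keras · Author: NVIDIA · (project source code / file source code)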
def test_trainable_weights():
    """Toggling ``trainable`` must move weights between the two collections.

    Exercises a functional ``Model`` and then a ``Sequential`` model. For
    each one, the model-level ``trainable`` flag is switched off and back on,
    and finally a single layer is frozen. After every step, the split between
    ``trainable_weights`` and ``non_trainable_weights`` is checked against
    the model's full ``weights`` list.
    """

    def expect(net, all_weights, trainable):
        # A trainable net reports every weight as trainable and none as
        # frozen; a frozen net is the mirror image.
        if trainable:
            assert net.trainable_weights == all_weights
            assert net.non_trainable_weights == []
        else:
            assert net.trainable_weights == []
            assert net.non_trainable_weights == all_weights

    def exercise(net, layer_index):
        # Snapshot the full weight list once, before any toggling.
        all_weights = net.weights
        expect(net, all_weights, True)

        net.trainable = False
        expect(net, all_weights, False)

        net.trainable = True
        expect(net, all_weights, True)

        # Freezing only the layer at ``layer_index`` is expected to freeze
        # every weight, i.e. that layer owns all of the model's weights.
        net.layers[layer_index].trainable = False
        expect(net, all_weights, False)

    # Functional API: index 0 is presumably the input layer, so the Dense
    # layer holding the weights is frozen via index 1.
    inp = Input(shape=(2,))
    out = Dense(1)(inp)
    exercise(Model(inp, out), 1)

    # Sequential API: the single Dense layer added here is frozen via index 0.
    seq = Sequential()
    seq.add(Dense(1, input_dim=2))
    exercise(seq, 0)
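# Comment list
# Table of contents
#
#
# Questions
#
#
# Interview experiences
#
#
# Articles
#
# WeChat
# Official account
#
# Scan the QR code to follow the official account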