test_topology.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:keras 作者: NVIDIA 项目源码 文件源码
def test_learning_phase():
    """Check that learning-phase usage propagates through layers and models.

    A ``Dropout`` layer depends on the learning phase while a ``Dense`` layer
    does not. This test verifies the ``_uses_learning_phase`` flag:

    - on the tensors each layer produces;
    - after a ``merge``;
    - on a ``Model`` built from those tensors;
    - on that model's outputs when it is called on fresh inputs.

    Finally it runs the graph with the learning phase fed as 0 and as 1. It
    checks that only the dropout branch's output changes between the two runs.
    """
    a = Input(shape=(32,), name='input_a')
    b = Input(shape=(32,), name='input_b')

    a_2 = Dense(16, name='dense_1')(a)
    dp = Dropout(0.5, name='dropout')
    b_2 = dp(b)

    assert dp.uses_learning_phase

    assert not a_2._uses_learning_phase
    assert b_2._uses_learning_phase

    # test merge: one learning-phase input is enough to flag the result
    m = merge([a_2, b_2], mode='concat')
    assert m._uses_learning_phase

    # Test recursion: the flag must propagate up to the containing Model
    model = Model([a, b], [a_2, b_2])
    # Smoke-check that input_spec is computable for a multi-input model;
    # the value itself is not asserted (previously it was only printed).
    _ = model.input_spec
    assert model.uses_learning_phase

    c = Input(shape=(32,), name='input_c')
    d = Input(shape=(32,), name='input_d')

    # Calling the model on new inputs: every output is expected to be flagged,
    # including the one coming from the Dense branch (c_2). Bind the second
    # output to d_2 rather than rebinding b_2, which names the original tensor.
    c_2, d_2 = model([c, d])
    assert c_2._uses_learning_phase
    assert d_2._uses_learning_phase

    # try actually running graph; the learning phase is fed as the last input
    fn = K.function(model.inputs + [K.learning_phase()], model.outputs)
    input_a_np = np.random.random((10, 32))
    input_b_np = np.random.random((10, 32))
    fn_outputs_no_dp = fn([input_a_np, input_b_np, 0])
    fn_outputs_dp = fn([input_a_np, input_b_np, 1])
    # output a: nothing changes. Compare element-wise with a tolerance instead
    # of exact equality of float sums.
    assert np.allclose(fn_outputs_no_dp[0], fn_outputs_dp[0])
    # output b: dropout applied when the learning phase is 1
    assert not np.allclose(fn_outputs_no_dp[1], fn_outputs_dp[1])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号