def test_learning_phase():
    """Check that the learning-phase flag propagates and is honoured at run time.

    Verifies that ``uses_learning_phase`` is set on a Dropout layer and on
    tensors downstream of it (through ``merge`` and through a nested
    ``Model`` call), but not on a Dense-only branch. Then builds a backend
    function that takes the learning phase as an extra input and checks
    that only the dropout branch's output differs between phase 0 and
    phase 1.
    """
    a = Input(shape=(32,), name='input_a')
    b = Input(shape=(32,), name='input_b')
    a_2 = Dense(16, name='dense_1')(a)
    dp = Dropout(0.5, name='dropout')
    b_2 = dp(b)

    # Dropout depends on the learning phase; the Dense-only branch does not.
    assert dp.uses_learning_phase
    assert not a_2._uses_learning_phase
    assert b_2._uses_learning_phase

    # test merge: merging with a learning-phase-dependent tensor keeps the flag
    m = merge([a_2, b_2], mode='concat')
    assert m._uses_learning_phase

    # Test recursion: the flag survives wrapping the graph in a Model and
    # calling that model on fresh inputs.
    model = Model([a, b], [a_2, b_2])
    assert model.uses_learning_phase
    c = Input(shape=(32,), name='input_c')
    d = Input(shape=(32,), name='input_d')
    # Use a fresh name instead of rebinding b_2, which still refers to the
    # dropout output above.
    c_2, d_2 = model([c, d])
    # Outputs of the model call are expected to carry the flag, including
    # c_2, even though it comes from the Dense branch (contrast with a_2).
    assert c_2._uses_learning_phase
    assert d_2._uses_learning_phase

    # try actually running graph, feeding the learning phase as an input
    fn = K.function(model.inputs + [K.learning_phase()], model.outputs)
    input_a_np = np.random.random((10, 32))
    input_b_np = np.random.random((10, 32))
    fn_outputs_no_dp = fn([input_a_np, input_b_np, 0])
    fn_outputs_dp = fn([input_a_np, input_b_np, 1])

    # output a: nothing changes (compare element-wise; equal sums alone
    # would not prove the outputs are equal)
    np.testing.assert_allclose(fn_outputs_no_dp[0], fn_outputs_dp[0])
    # output b: dropout applied
    assert not np.allclose(fn_outputs_no_dp[1], fn_outputs_dp[1])
评论列表
文章目录