def test_multi_continuous_dqn_input():
    """Smoke-test NAFAgent with two continuous observation inputs.

    Builds the three networks handed to NAFAgent (V, mu and L). Each one
    consumes two windowed observations of shape (2, 3) and (2, 4). The
    function then compiles the agent and runs a 10-step ``fit``. There are no
    explicit assertions: the test passes if construction, ``compile`` and
    ``fit`` complete without raising.
    """
    nb_actions = 2
    # n * (n + 1) // 2 outputs; presumably the entries of a lower-triangular
    # nb_actions x nb_actions matrix in NAF -- confirm against NAFAgent.
    nb_l_outputs = (nb_actions * nb_actions + nb_actions) // 2

    def windowed_observations():
        """Create both observation inputs; return (inputs, flattened concat)."""
        # Leading dim 2 presumably matches window_length=2 of the memory below.
        obs_a = Input(shape=(2, 3))
        obs_b = Input(shape=(2, 4))
        joined = Concatenate()([obs_a, obs_b])
        return [obs_a, obs_b], Flatten()(joined)

    # V network: observations -> single scalar output.
    v_inputs, v_features = windowed_observations()
    V_model = Model(inputs=v_inputs, outputs=Dense(1)(v_features))

    # mu network: observations -> one output per action dimension.
    mu_inputs, mu_features = windowed_observations()
    mu_model = Model(inputs=mu_inputs, outputs=Dense(nb_actions)(mu_features))

    # L network: flattened observations concatenated with the action vector.
    l_obs_inputs, l_features = windowed_observations()
    action_input = Input(shape=(nb_actions,))
    l_joined = Concatenate()([l_features, action_input])
    # Input order is [action, obs_a, obs_b]; presumably the order in which
    # NAFAgent feeds the L network -- confirm.
    L_model = Model(
        inputs=[action_input] + l_obs_inputs,
        outputs=Dense(nb_l_outputs)(l_joined),
    )

    agent = NAFAgent(
        nb_actions=nb_actions,
        V_model=V_model,
        L_model=L_model,
        mu_model=mu_model,
        memory=SequentialMemory(limit=10, window_length=2),
        nb_steps_warmup=5,
        batch_size=4,
        processor=MultiInputProcessor(nb_inputs=2),
    )
    agent.compile('sgd')
    agent.fit(MultiInputTestEnv([(3,), (4,)]), nb_steps=10)
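# Residue from the web page this test was copied from (page labels
# "Comment list" and "Table of contents"); not part of the test.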