def test_multi_dqn_input():
    """Smoke-test DQNAgent on a two-input model, with and without double DQN.

    For each value of ``enable_double_dqn`` a DQNAgent is built around a
    model that concatenates two inputs, compiled with SGD, and trained for
    10 steps on ``MultiInputTestEnv``. The test passes if ``compile`` and
    ``fit`` complete without raising.
    """
    for double_dqn in (True, False):
        # Build the model, memory and processor inside the loop so the two
        # configurations are independent: sharing them would let the second
        # agent sample transitions stored by the first run and continue
        # training the first run's weights.
        #
        # Input shapes (2, 3) and (2, 4): the leading 2 presumably matches
        # window_length=2 below and the trailing dims match the (3,) and
        # (4,) observation shapes given to MultiInputTestEnv -- confirm.
        input1 = Input(shape=(2, 3))
        input2 = Input(shape=(2, 4))
        x = Concatenate()([input1, input2])
        x = Flatten()(x)
        x = Dense(2)(x)  # 2 outputs, matching nb_actions=2 below
        model = Model(inputs=[input1, input2], outputs=x)

        memory = SequentialMemory(limit=10, window_length=2)
        processor = MultiInputProcessor(nb_inputs=2)
        agent = DQNAgent(model, memory=memory, nb_actions=2,
                         nb_steps_warmup=5, batch_size=4,
                         processor=processor, enable_double_dqn=double_dqn)
        agent.compile('sgd')
        agent.fit(MultiInputTestEnv([(3,), (4,)]), nb_steps=10)
评论列表
文章目录