test_dqn.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:keras-rl 作者: matthiasplappert 项目源码 文件源码
def test_multi_dqn_input():
    input1 = Input(shape=(2, 3))
    input2 = Input(shape=(2, 4))
    x = Concatenate()([input1, input2])
    x = Flatten()(x)
    x = Dense(2)(x)
    model = Model(inputs=[input1, input2], outputs=x)

    memory = SequentialMemory(limit=10, window_length=2)
    processor = MultiInputProcessor(nb_inputs=2)
    for double_dqn in (True, False):
        agent = DQNAgent(model, memory=memory, nb_actions=2, nb_steps_warmup=5, batch_size=4,
                         processor=processor, enable_double_dqn=double_dqn)
        agent.compile('sgd')
        agent.fit(MultiInputTestEnv([(3,), (4,)]), nb_steps=10)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号