test_dqn.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:keras-rl 作者: matthiasplappert 项目源码 文件源码
def test_multi_continuous_dqn_input():
    """Smoke-test a NAFAgent whose V, mu and L models each take two inputs.

    Three Keras models are built, each fed by a pair of inputs of shape
    (2, 3) and (2, 4) that are concatenated and flattened before a single
    Dense layer.  The L model additionally receives an action input, which
    is listed first among its model inputs.  The agent is compiled with
    'sgd' and fitted for 10 steps; the test passes if nothing raises.
    """
    nb_actions = 2
    # NOTE(review): the leading 2 presumably mirrors window_length=2 below and
    # the trailing 3 / 4 the env's (3,) and (4,) shapes -- confirm.
    obs_shapes = [(2, 3), (2, 4)]

    # V model: two inputs -> one output unit.
    v_inputs = [Input(shape=shape) for shape in obs_shapes]
    v_merged = Concatenate()(v_inputs)
    v_flat = Flatten()(v_merged)
    v_out = Dense(1)(v_flat)
    V_model = Model(inputs=v_inputs, outputs=v_out)

    # mu model: two inputs -> one output unit per action.
    mu_inputs = [Input(shape=shape) for shape in obs_shapes]
    mu_merged = Concatenate()(mu_inputs)
    mu_flat = Flatten()(mu_merged)
    mu_out = Dense(nb_actions)(mu_flat)
    mu_model = Model(inputs=mu_inputs, outputs=mu_out)

    # L model: the same two inputs plus an action input of width nb_actions.
    l_obs_inputs = [Input(shape=shape) for shape in obs_shapes]
    l_action_input = Input(shape=(nb_actions,))
    l_merged = Concatenate()(l_obs_inputs)
    l_joined = Concatenate()([Flatten()(l_merged), l_action_input])
    l_units = (nb_actions * nb_actions + nb_actions) // 2
    l_out = Dense(l_units)(l_joined)
    # The action input comes first in the model's input list.
    L_model = Model(inputs=[l_action_input] + l_obs_inputs, outputs=l_out)

    agent = NAFAgent(
        nb_actions=nb_actions,
        V_model=V_model,
        L_model=L_model,
        mu_model=mu_model,
        memory=SequentialMemory(limit=10, window_length=2),
        nb_steps_warmup=5,
        batch_size=4,
        processor=MultiInputProcessor(nb_inputs=2),
    )
    agent.compile('sgd')

    env = MultiInputTestEnv([(3,), (4,)])
    agent.fit(env, nb_steps=10)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号