pg.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:rl 作者: Shmuma 项目源码 文件源码
def make_model(state_shape, n_actions):
    """Build a policy network whose graph also emits the policy-gradient loss.

    The model takes three inputs:

    * ``input``     -- a stack of ``HISTORY_STEPS`` states, i.e. a tensor of
      shape ``(HISTORY_STEPS,) + state_shape`` per sample;
    * ``action``    -- one int32 action index per sample, shape ``(1,)``;
    * ``advantage`` -- one scalar weight per sample, shape ``(1,)``.

    and produces two outputs:

    * ``policy`` -- softmax probabilities over ``n_actions``;
    * ``loss``   -- ``-advantage * log(1e-6 + policy[action])`` per sample,
      shape ``(1,)``.

    Args:
        state_shape: tuple describing a single state; it is concatenated
            onto ``(HISTORY_STEPS,)`` to form the input shape.
        n_actions: number of discrete actions (size of the softmax layer
            and of the one-hot encoding used in the loss).

    Returns:
        A Keras ``Model`` with inputs ``[input, action, advantage]`` and
        outputs ``[policy, loss]``.
    """
    in_t = Input(shape=(HISTORY_STEPS,) + state_shape, name='input')
    action_t = Input(shape=(1,), dtype='int32', name='action')
    advantage_t = Input(shape=(1,), name='advantage')

    # Simple fully connected trunk: flatten the state history, two ReLU
    # layers, then a softmax head over the actions.
    fl_t = Flatten(name='flat')(in_t)
    l1_t = Dense(SIMPLE_L1_SIZE, activation='relu', name='l1')(fl_t)
    l2_t = Dense(SIMPLE_L2_SIZE, activation='relu', name='l2')(l1_t)
    policy_t = Dense(n_actions, name='policy', activation='softmax')(l2_t)

    def loss_func(args):
        """Return -advantage * log-probability of the taken action."""
        p_t, act_t, adv_t = args
        # act_t is (batch, 1), so one_hot yields (batch, 1, n_actions);
        # drop the middle axis to line it up with p_t (batch, n_actions).
        oh_t = K.squeeze(K.one_hot(act_t, n_actions), 1)
        # Select the probability of the chosen action; the 1e-6 offset
        # keeps log() finite when that probability is zero.
        p_oh_t = K.log(1e-6 + K.sum(oh_t * p_t, axis=-1, keepdims=True))
        return -(adv_t * p_oh_t)

    loss_t = Lambda(loss_func, output_shape=(1,), name='loss')(
        [policy_t, action_t, advantage_t])

    # Inputs/outputs are passed positionally: the keyword names differ
    # between Keras 1 (input/output) and Keras 2+ (inputs/outputs), while
    # the positional order is the same in both.
    return Model([in_t, action_t, advantage_t], [policy_t, loss_t])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号