cartpole.py source code

python

Project: learning-tf    Author: unixpickle
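import gym
import tensorflow as tf  # uses the TensorFlow 1.x API (tf.placeholder, tf.Session)

# policy() and rollouts() are helpers not shown in this excerpt.


def main():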
    """
    Train a policy on the CartPole-v0 environment.
    """
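    # Batch of CartPole observations (4 state values each).
    observations = tf.placeholder(tf.float32, shape=[None, 4])
    # Softmax over the policy network's outputs gives per-action probabilities.
    out_probs = tf.nn.softmax(policy(observations))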

    # Selected actions (one-hot vectors) and cumulative
    # episode rewards for those actions.
    actions = tf.placeholder(tf.float32, shape=[None, 2])
    goodnesses = tf.placeholder(tf.float32, shape=[None, 1])

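    # Policy-gradient loss: actions*goodnesses broadcasts each reward across
    # both action columns, and tensordot with axes=2 sums log(prob) * weight
    # over the batch and action axes, so only the taken actions contribute.
    # Dividing by the batch size turns the sum into a mean.
    loss = -tf.tensordot(tf.log(out_probs), actions*goodnesses, axes=2)
    loss /= tf.cast(tf.shape(actions)[0], tf.float32)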
    opt = tf.train.AdamOptimizer(learning_rate=1e-2)
    minimize = opt.minimize(loss)

    env = gym.make('CartPole-v0')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
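        # Loop forever: sample rollouts with the current policy, then take
        # one Adam step on the collected batch.
        while True: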
            obs, acts, rews, mean_rew = rollouts(env, sess, observations,
                                                 out_probs, 2000)
            loss_args = {
                observations: obs,
                actions: acts,
                goodnesses: rews
            }
            print('mean_reward=%f' % (mean_rew,))
            sess.run(minimize, feed_dict=loss_args)
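To make the loss in main() concrete, here is a minimal pure-Python sketch of the same computation on a tiny batch, with no TensorFlow or Gym required. It assumes, as the comment in main() suggests, that every action in an episode is credited with that episode's total reward. The names `one_hot`, `episode_goodnesses`, and `pg_loss` are hypothetical and chosen for illustration only; they are not functions from learning-tf, and the real `rollouts` helper may build its arrays differently.

```python
import math


def one_hot(action, num_actions=2):
    """Encode an action index as a one-hot row, matching the `actions` placeholder."""
    return [1.0 if i == action else 0.0 for i in range(num_actions)]


def episode_goodnesses(rewards):
    """Hypothetical helper: credit every step with its episode's total reward.

    Returns one single-element row per step, matching the [None, 1] shape of
    the `goodnesses` placeholder.
    """
    total = sum(rewards)
    return [[total] for _ in rewards]


def pg_loss(probs, actions, goodnesses):
    """Pure-Python equivalent of the loss built in main().

    tensordot(..., axes=2) sums log(prob) * action * goodness over both the
    batch axis and the action axis; dividing by the batch size gives the mean.
    """
    total = 0.0
    for p_row, a_row, (g,) in zip(probs, actions, goodnesses):
        for p, a in zip(p_row, a_row):
            total += math.log(p) * a * g
    return -total / len(probs)


# A made-up 3-step episode; CartPole gives a reward of +1 per step survived.
probs = [[0.5, 0.5], [0.8, 0.2], [0.25, 0.75]]  # invented policy outputs
actions = [one_hot(a) for a in (0, 0, 1)]        # actions taken at each step
goodnesses = episode_goodnesses([1.0, 1.0, 1.0])  # every row is [3.0]

loss = pg_loss(probs, actions, goodnesses)
# Same value as the batch mean of -R * log p(action taken):
# -(3*log 0.5 + 3*log 0.8 + 3*log 0.75) / 3 = -log 0.3, about 1.204
print(round(loss, 3))  # prints 1.204
```

Because each log-probability is weighted by a positive episode return, minimizing this loss pushes up the probability of actions from long (high-reward) episodes more strongly than those from short ones. This is essentially the REINFORCE-style policy-gradient update that `opt.minimize(loss)` applies with Adam, here without a baseline.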