policy_gradiant.py 文件源码

python
阅读 14 收藏 0 点赞 0 评论 0

项目:a3c 作者: hercky 项目源码 文件源码
def __init__(self, ob_space, action_space, **usercfg):
        """
        Build a one-hidden-layer softmax policy and compile its Theano functions.

        Args:
            ob_space: object whose ``shape[0]`` gives the observation size.
            action_space: object whose ``n`` gives the number of discrete actions.
            **usercfg: keyword overrides merged on top of the default settings
                stored in ``self.config``.

        Sets:
            self.config: dict of algorithm settings (defaults plus overrides).
            self.pg_update: compiled function taking (observations, actions,
                advantages, stepsize); it returns nothing and applies the
                parameter updates produced by ``rmsprop_updates``.
            self.compute_prob: compiled function mapping observations to the
                per-action probabilities of the policy.
        """
        obs_dim = ob_space.shape[0]
        num_actions = action_space.n

        # Default algorithm settings; any keyword argument overrides its entry.
        self.config = {
            'episode_max_length': 100,
            'timesteps_per_batch': 10000,
            'n_iter': 100,
            'gamma': 1.0,
            'stepsize': 0.05,
            'nhid': 20,
        }
        self.config.update(usercfg)
        hidden_units = self.config['nhid']

        # Symbolic batch inputs: the leading axis indexes the timestep.
        observations = T.fmatrix()   # one observation per row
        actions = T.ivector()        # discrete action taken at each timestep
        advantages = T.fvector()     # advantage estimate at each timestep

        # NOTE(review): parameters are stored as float64 while the symbolic
        # inputs above are float32 -- confirm this dtype mix is intended.
        def as_shared(values):
            return theano.shared(values.astype('float64'))

        # Network weights. The input layer is scaled by 1/sqrt(fan-in) and the
        # output layer starts close to zero.
        hidden_w = as_shared(
            np.random.randn(obs_dim, hidden_units) / np.sqrt(obs_dim))
        hidden_b = as_shared(np.zeros(hidden_units))
        output_w = as_shared(1e-4 * np.random.randn(hidden_units, num_actions))
        output_b = as_shared(np.zeros(num_actions))
        weights = [hidden_w, hidden_b, output_w, output_b]

        # Forward pass: tanh hidden layer followed by a softmax over actions.
        hidden = T.tanh(observations.dot(hidden_w) + hidden_b[None, :])
        action_probs = T.nnet.softmax(hidden.dot(output_w) + output_b[None, :])
        n_steps = observations.shape[0]

        # Advantage-weighted log-probability of the chosen actions, divided by
        # the number of timesteps; its gradient is the policy gradient.
        chosen_logp = T.log(action_probs[T.arange(n_steps), actions])
        objective = chosen_logp.dot(advantages) / n_steps
        step = T.fscalar()
        gradients = T.grad(objective, weights)

        # RMSProp is used for the parameter step; the original author noted
        # that plain SGD did not work well here.
        param_updates = rmsprop_updates(gradients, weights, step)

        self.pg_update = theano.function(
            [observations, actions, advantages, step],
            [],
            updates=param_updates,
            allow_input_downcast=True,
        )
        self.compute_prob = theano.function(
            [observations],
            action_probs,
            allow_input_downcast=True,
        )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号