value_functions.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:rl_algorithms 作者: DanielTakeshi 项目源码 文件源码
def __init__(self, session, ob_dim=None, n_epochs=20, stepsize=1e-3):
        """Build the value-function network and its training op.

        The network gets constructed upon initialization, so future calls to
        self.fit will reuse the same graph and variables.

        Right now we assume a preprocessing which results in ob_dim*2+1
        input dimensions, and we assume a fixed neural network architecture
        (input-50-50-1, fully connected with tanh nonlinearities), which we
        should probably change.

        The number of outputs is one, so that ypreds_n is the predicted vector
        of state values, to be compared against ytargs_n. Since ytargs_n is of
        shape (n,), the final predictions, which would otherwise be of shape
        (n,1), are reshaped to (n,).

        Args:
            session: Stored as self.sess; presumably the TensorFlow session
                used later to run fit_op and ypreds_n -- confirm with callers.
            ob_dim: Raw observation dimensionality. Required despite the
                None default, because the input width is ob_dim*2+1.
            n_epochs: Stored as self.n_epochs (not used in this method).
            stepsize: Learning rate handed to the Adam optimizer.

        Raises:
            ValueError: If ob_dim is None.
        """
        # Fail early with a clear message; otherwise `None*2+1` below raises
        # an opaque TypeError in the middle of graph construction.
        if ob_dim is None:
            raise ValueError(
                "ob_dim must be provided: the network input has "
                "ob_dim*2+1 features."
            )

        # Value function V(s_t) (or b(s_t)), parameterized as a neural network.
        self.ob_no = tf.placeholder(shape=[None, ob_dim*2+1], name="nnvf_ob", dtype=tf.float32)
        self.h1 = layers.fully_connected(self.ob_no, 
                num_outputs=50,
                weights_initializer=layers.xavier_initializer(uniform=True),
                activation_fn=tf.nn.tanh)
        self.h2 = layers.fully_connected(self.h1,
                num_outputs=50,
                weights_initializer=layers.xavier_initializer(uniform=True),
                activation_fn=tf.nn.tanh)
        # Linear output layer (no activation) producing one value per row.
        self.ypreds_n = layers.fully_connected(self.h2,
                num_outputs=1,
                weights_initializer=layers.xavier_initializer(uniform=True),
                activation_fn=None)
        # Flatten (?,1) --> (?,) so predictions line up with ytargs_n.
        self.ypreds_n = tf.reshape(self.ypreds_n, [-1])

        # Form the loss function, which is the simple (mean) L2 error.
        self.n_epochs = n_epochs
        self.lrate    = stepsize
        self.ytargs_n = tf.placeholder(shape=[None], name="nnvf_y", dtype=tf.float32)
        self.l2_error = tf.reduce_mean(tf.square(self.ypreds_n - self.ytargs_n))
        self.fit_op   = tf.train.AdamOptimizer(self.lrate).minimize(self.l2_error)
        self.sess     = session
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号