network.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:ReinforcementLearning 作者: persistforever 项目源码 文件源码
def train_one_batch(self):
        self.actions = tensor.vector(name='actions', dtype='int64')
        self.y = tensor.vector(name='y', dtype=theano.config.floatX)
        cost = self.output_vector[self.actions].sum() / self.actions.shape[0]
        coef = (self.y - self.output_vector[self.actions]).sum() / self.actions.shape[0]
        grads = tensor.grad(cost, wrt=self.params.values())
        grads = [coef*t for t in grads]

        lr = tensor.scalar(name='lr')
        f_update = self._adadelta(lr, self.params, grads)

        def update_function(states, actions, y, yita):
            f_update(numpy.array(yita, dtype=theano.config.floatX))
            return

        return update_function
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号