ram.py source code

python

Project: ram    Author: amasky
# module-level imports this method relies on (the original file is assumed
# to import these; listed here so the snippet is self-contained)
import numpy as np
import chainer
import chainer.functions as F

def __call__(self, x, t, train=True):
    # wrap inputs; volatile=True disables graph construction at test time (Chainer 1.x)
    x = chainer.Variable(self.xp.asarray(x), volatile=not train)
    t = chainer.Variable(self.xp.asarray(t), volatile=not train)
    bs = x.data.shape[0]  # batch size
    self.clear(bs, train)

    # initial glimpse location, sampled uniformly from [-1, 1)^2
    l = np.random.uniform(-1, 1, size=(bs, 2)).astype(np.float32)
    l = chainer.Variable(self.xp.asarray(l), volatile=not train)

    # unroll the core network for n_steps glimpses
    sum_ln_pi = 0
    self.forward(x, train, action=False, init_l=l)
    for i in range(1, self.n_steps):
        # emit a classification only at the final step
        action = (i == self.n_steps - 1)
        l, ln_pi, y, b = self.forward(x, train, action)
        if train:
            sum_ln_pi += ln_pi

    # classification loss: softmax cross entropy on the final prediction
    self.loss_action = F.softmax_cross_entropy(y, t)
    self.loss = self.loss_action
    self.accuracy = F.accuracy(y, t)

    if train:
        # reward: 1 if the prediction is correct, 0 otherwise
        conditions = self.xp.argmax(y.data, axis=1) == t.data
        r = self.xp.where(conditions, 1., 0.).astype(np.float32)

        # baseline loss: squared error between reward and predicted baseline b
        self.loss_base = F.mean_squared_error(r, b)
        self.loss += self.loss_base

        # REINFORCE rule: -log pi weighted by the centered reward (r - b),
        # averaged over the n_steps - 1 location samples and the batch
        mean_ln_pi = sum_ln_pi / (self.n_steps - 1)
        self.loss_reinforce = F.sum(-mean_ln_pi * (r - b)) / bs
        self.loss += self.loss_reinforce

    return self.loss
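
A minimal sketch of how this __call__ could be driven in a Chainer 1.x training loop (the volatile flag above ties the code to that API generation). The RAM() constructor, its arguments, and the batch shapes below are assumptions for illustration, not taken from the project source.

python

import numpy as np
from chainer import optimizers

# hypothetical setup: the model class name and constructor arguments are assumed
model = RAM()
optimizer = optimizers.Adam()
optimizer.setup(model)

# dummy batch: float32 images and int32 labels (shapes are assumptions)
x_batch = np.random.rand(32, 1, 28, 28).astype(np.float32)
t_batch = np.random.randint(0, 10, size=32).astype(np.int32)

model.zerograds()  # reset accumulated gradients (Chainer 1.x API)
loss = model(x_batch, t_batch, train=True)
loss.backward()    # backprop through the action, baseline, and REINFORCE terms
optimizer.update()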