ewc_mnist.py source file

python

Project: chainer-EWC  Author: okdshin
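import chainer.functions as F
from chainer import reporter


def __call__(self, *args):
    # Method of a classifier-style wrapper: the last positional argument
    # is the target, everything before it is input to the predictor.
    x = args[:-1]
    t = args[-1]
    self.y = None
    self.loss = None
    self.accuracy = None
    self.y = self.predictor(*x)
    # Ordinary classification loss on the current task.
    self.loss = F.softmax_cross_entropy(self.y, t)

    # EWC penalty: applied only once parameters and Fisher values from a
    # previous task have been stored.
    if self.stored_variable_list is not None and \
            self.fisher_list is not None:  # i.e. Stored
        for i in range(len(self.variable_list)):
            # variable_list[i][1] is indexed as if each entry were a
            # (name, parameter) pair; the penalty for parameter i is
            # lam/2 * sum(fisher * (current - stored)^2).
            self.loss += self.lam/2. * F.sum(
                    self.fisher_list[i] *
                    F.square(self.variable_list[i][1] -
                             self.stored_variable_list[i]))
    reporter.report({'loss': self.loss}, self)
    if self.compute_accuracy:
        self.accuracy = F.accuracy(self.y, t)
        reporter.report({'accuracy': self.accuracy}, self)
    return self.loss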
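
The penalty term added to the loss above can be checked by hand on small numbers. The following is only a minimal sketch in plain Python, not code from chainer-EWC: the function name `ewc_penalty` and the representation of each parameter as a flat list of floats are assumptions made for illustration. It computes lam/2 * sum(fisher * (current - stored)^2) for every parameter and adds the results, mirroring the loop in `__call__`.

```python
def ewc_penalty(params, stored_params, fishers, lam):
    """Sum over parameters of lam/2 * sum(fisher * (current - stored)^2).

    Hypothetical helper: every parameter is assumed to be a flat list of
    floats, unlike the Chainer variables used in the original method.
    """
    penalty = 0.0
    for theta, theta_star, fisher in zip(params, stored_params, fishers):
        penalty += lam / 2.0 * sum(
            f * (p - s) ** 2 for f, p, s in zip(fisher, theta, theta_star))
    return penalty


# Two toy "parameters": only entries that moved away from their stored
# values contribute, each weighted by its Fisher value.
current = [[1.0, 2.0], [0.5]]
stored = [[0.0, 2.0], [0.0]]
fisher = [[2.0, 1.0], [4.0]]
print(ewc_penalty(current, stored, fisher, lam=1.0))  # 1.5
```

Entries with a large Fisher value are pulled back toward their stored values more strongly, and an entry that has not moved contributes nothing to the penalty.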