ewc_mnist.py source code

python

Project: chainer-EWC, Author: okdshin
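import numpy as np
import chainer.functions as F


def compute_fisher(self, dataset):
    # Estimate the diagonal of the Fisher information for every parameter
    # in self.variable_list (entries are indexed as (name, parameter)).
    fisher_accum_list = [
        np.zeros(var[1].shape) for var in self.variable_list]

    for _ in range(self.num_samples):
        # Draw one random input from the dataset; its label is not used.
        x, _ = dataset[np.random.randint(len(dataset))]
        y = self.predictor(np.array([x]))
        # Sample a class from the model's own predictive distribution.
        # Cast to float64 and renormalize, because a float32 softmax output
        # may not sum to 1 closely enough for np.random.choice.
        prob_list = F.softmax(y)[0].data.astype(np.float64)
        prob_list /= prob_list.sum()
        class_index = np.random.choice(len(prob_list), p=prob_list)
        # Log-probability of the sampled class; its gradient is the score.
        loss = F.log_softmax(y)[0, class_index]
        self.cleargrads()
        loss.backward()
        # Accumulate the squared gradient of each parameter.
        for i in range(len(self.variable_list)):
            fisher_accum_list[i] += np.square(
                self.variable_list[i][1].grad)

    # Average over the samples to obtain the Monte Carlo estimate.
    self.fisher_list = [
        F_accum / self.num_samples for F_accum in fisher_accum_list]
    return self.fisher_list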
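compute_fisher builds a Monte Carlo estimate of the diagonal Fisher information, F_i ≈ (1/N) Σ_n (∂ log p(c_n | x_n; θ) / ∂θ_i)^2, where each x_n is a random dataset input and c_n is sampled from the model's predicted class distribution rather than taken from the label. The NumPy-only sketch below reproduces that procedure for a hypothetical softmax-regression model, p(c | x) = softmax(Wx + b), so the gradients can be written by hand instead of coming from Chainer's backward(). The names diag_fisher_softmax_regression and ewc_penalty are illustrative, not part of chainer-EWC. ewc_penalty shows the standard EWC quadratic penalty, (λ/2) Σ_i F_i (θ_i - θ*_i)^2, which is what a Fisher estimate such as fisher_list is normally used for; this file excerpt does not show how the project itself applies it.

```python
import numpy as np


def softmax(z):
    # Numerically stable softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def diag_fisher_softmax_regression(W, b, xs, num_samples, rng):
    # Hypothetical model (assumption for illustration): p(c | x) = softmax(W @ x + b).
    fisher_W = np.zeros_like(W, dtype=np.float64)
    fisher_b = np.zeros_like(b, dtype=np.float64)
    for _ in range(num_samples):
        # Pick a random input, as compute_fisher does with the dataset.
        x = xs[rng.integers(len(xs))]
        p = softmax(W @ x + b).astype(np.float64)
        p /= p.sum()
        # Sample a class from the model's own predictive distribution.
        c = rng.choice(len(p), p=p)
        # Gradient of log p(c | x) w.r.t. the logits: onehot(c) - p.
        g = -p
        g[c] += 1.0
        # Chain rule: d/dW = outer(g, x), d/db = g; accumulate the squares.
        fisher_W += np.outer(g, x) ** 2
        fisher_b += g ** 2
    # Average over the samples, like the final step of compute_fisher.
    return fisher_W / num_samples, fisher_b / num_samples


def ewc_penalty(params, anchor_params, fishers, lam):
    # Standard EWC regularizer: (lam / 2) * sum_i F_i * (theta_i - theta*_i)^2,
    # where anchor_params are the weights learned on the previous task.
    return 0.5 * lam * sum(
        float(np.sum(fisher * (theta - theta_star) ** 2))
        for theta, theta_star, fisher in zip(params, anchor_params, fishers))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    W = 0.1 * rng.normal(size=(3, 4))
    b = np.zeros(3)
    xs = rng.normal(size=(20, 4))
    F_W, F_b = diag_fisher_softmax_regression(W, b, xs, 500, rng)
    print(F_W.shape, F_b.shape)  # (3, 4) (3,)
    # Moving the weights away from the anchor is penalized.
    print(ewc_penalty([W + 0.1, b], [W, b], [F_W, F_b], lam=100.0))
```

For the bias b at a single fixed input, the exact diagonal Fisher is p(1 - p) per class, which gives a simple check that the sampled estimate converges to the right quantity.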