network.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:EWC 作者: stokesj 项目源码 文件源码
def create_fisher_ops(self):
        self.fisher_diagonal = self.bias_shaped_variables(name='bias_grads2', c=0.0, trainable=False) +\
                               self.weight_shaped_variables(name='weight_grads2', c=0.0, trainable=False)

        self.fisher_accumulate_op = [tf.assign_add(f1, f2) for f1, f2 in zip(self.fisher_diagonal, self.fisher_minibatch)]
        scale = 1 / float(self.ewc_batches * self.ewc_batch_size)
        self.fisher_full_batch_average_op = [tf.assign(var, scale * var) for var in self.fisher_diagonal]
        self.fisher_zero_op = [tf.assign(tensor, tf.zeros_like(tensor)) for tensor in self.fisher_diagonal]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号