def compute_fisher(self, dataset):
    """Estimate the diagonal of the Fisher information by sampling.

    Repeats ``self.num_samples`` times:

    1. Draw a random example ``x`` from ``dataset``.
    2. Run ``self.predictor`` on it.
    3. Sample a class index from the softmax of the prediction.
    4. Back-propagate the log-probability of that sampled class.
    5. Add the element-wise square of each parameter's gradient to a
       running sum.

    The sums are then divided by ``self.num_samples``.

    Args:
        dataset: Sized, indexable collection whose items unpack into
            exactly two values; only the first (the input ``x``) is used.
            Presumably ``(x, label)`` pairs -- confirm against callers.

    Returns:
        A list with one numpy array per entry of ``self.variable_list``,
        shaped like that entry's parameter (``var[1]``). The same list is
        also stored on ``self.fisher_list``.

    Reads ``self.variable_list`` (presumably ``(name, param)`` pairs, e.g.
    from ``namedparams()`` -- TODO confirm), ``self.predictor``,
    ``self.num_samples`` and calls ``self.cleargrads()``.
    """
    fisher_accum_list = [
        np.zeros(var[1].shape) for var in self.variable_list]
    for _ in range(self.num_samples):
        x, _label = dataset[np.random.randint(len(dataset))]
        # Wrap the single example in a batch of size 1 for the predictor.
        y = self.predictor(np.array([x]))
        # NOTE(review): on GPU ``.data`` would not be a NumPy array and
        # np.random.choice would fail -- this appears to assume CPU arrays.
        prob_list = F.softmax(y)[0].data
        # The label is sampled from the model's own predictive distribution
        # (not taken from the dataset), as the Fisher definition requires.
        class_index = np.random.choice(len(prob_list), p=prob_list)
        loss = F.log_softmax(y)[0, class_index]
        self.cleargrads()
        loss.backward()
        for fisher_accum, var in zip(fisher_accum_list, self.variable_list):
            grad = var[1].grad
            # cleargrads() resets gradients to None, so a parameter that
            # backward() never reached still has grad None. Its Fisher
            # contribution is zero; squaring None would raise TypeError.
            if grad is not None:
                # In-place add updates the array held in fisher_accum_list.
                fisher_accum += np.square(grad)
    self.fisher_list = [
        fisher_accum / self.num_samples for fisher_accum in fisher_accum_list]
    return self.fisher_list
评论列表
文章目录