def _set_hook_func(self):
    def func_b(module, grad_in, grad_out):
        # Save the gradient flowing into this module
        self.all_grads[id(module)] = grad_in[0].cpu()
        # Cut off negative gradients at ReLU layers
        if isinstance(module, nn.ReLU):
            return (torch.clamp(grad_in[0], min=0.0),)

    for module in self.model.named_modules():
        module[1].register_backward_hook(func_b)
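To show how this hook registration might fit into a complete pipeline, here is a minimal sketch of a hypothetical GuidedBackprop wrapper (the class name, the generate method, and the target-class handling are assumptions for illustration, not part of the original code): it registers the hooks, backpropagates the score of a chosen class, and reads the clamped gradients off the input image.

import torch
import torch.nn as nn

class GuidedBackprop:
    """Hypothetical wrapper around _set_hook_func, for illustration only."""

    def __init__(self, model):
        self.model = model.eval()
        self.all_grads = {}
        self._set_hook_func()

    def _set_hook_func(self):
        def func_b(module, grad_in, grad_out):
            # Guard against modules whose first grad input is None
            if grad_in[0] is not None:
                self.all_grads[id(module)] = grad_in[0].cpu()
            # Cut off negative gradients at ReLU layers
            if isinstance(module, nn.ReLU):
                return (torch.clamp(grad_in[0], min=0.0),)

        for name, module in self.model.named_modules():
            module.register_backward_hook(func_b)

    def generate(self, image, target_class):
        # image: (1, C, H, W) tensor; gradients are taken w.r.t. it
        image = image.clone().detach().requires_grad_(True)
        output = self.model(image)
        self.model.zero_grad()
        # Backpropagate only the score of the target class
        one_hot = torch.zeros_like(output)
        one_hot[0, target_class] = 1.0
        output.backward(gradient=one_hot)
        # The saliency map is the gradient w.r.t. the input image
        return image.grad.detach().cpu()

# Example usage (assumes a classification model and a preprocessed image tensor):
# gbp = GuidedBackprop(torchvision.models.vgg16(pretrained=True))
# saliency = gbp.generate(image, target_class=243)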