def _augment_module_post(net: nn.Module, callback_dict: dict) -> tuple:
    """Register a backward hook on the sub-modules of ``net``.

    Every module yielded by ``net.named_modules()`` is hooked, except
    ``net`` itself and any ``nn.Sequential`` container. Each hook is
    ``_backward_hook`` with the module's name, ``callback_dict`` and a
    shared ``vis_param_dict`` pre-bound as keyword arguments.

    Args:
        net: The model whose sub-modules receive the hooks.
        callback_dict: Passed through unchanged to every ``_backward_hook``
            invocation.

    Returns:
        A 2-tuple ``(vis_param_dict, remove_handles)``:

        * ``vis_param_dict`` -- dict with keys ``'layer'`` and ``'index'``
          (both initially ``None``) and ``'method'`` (initially
          ``GradType.NAIVE``). The very same dict object is handed to
          every hook.
        * ``remove_handles`` -- zero-argument callable that removes every
          hook registered by this call.
    """
    vis_param_dict = {
        'layer': None,
        'index': None,
        'method': GradType.NAIVE,
    }
    hook_handles = []

    for module_name, module in net.named_modules():
        # Skip the root module and pure Sequential containers.
        if module is net or isinstance(module, nn.Sequential):
            continue
        # Hook every remaining layer up front, in case it is needed later.
        hook = partial(
            _backward_hook,
            module_name=module_name,
            callback_dict=callback_dict,
            vis_param_dict=vis_param_dict,
        )
        # NOTE(review): recent PyTorch releases deprecate
        # register_backward_hook in favour of register_full_backward_hook;
        # kept as-is because the gradients passed to the hook may differ --
        # confirm against _backward_hook before migrating.
        hook_handles.append(module.register_backward_hook(hook))

    def remove_handles():
        """Remove every backward hook registered above."""
        for handle in hook_handles:
            handle.remove()

    return vis_param_dict, remove_handles
评论列表
文章目录