def norms_of_d_dynamics_d_hypers(fd=None):
    """
    In `ForwardHG` records the norm of the partial derivatives of the dynamics w.r.t. the hyperparameters.

    Builds a recorder function which, when called, delegates to ``Records.tensors``
    over the tensors in ``hg.d_dynamics_d_hypers`` with ``op=tf.norm``.

    :param fd: optional callable forwarded unchanged to ``Records.tensors`` as its
               ``fd`` argument. When ``None`` (the default), it is replaced by
               ``lambda stp, rs: rs``, which simply returns its second argument.
               How ``fd`` is used is defined by ``Records.tensors`` (presumably it
               selects the feed dict for evaluating the tensors -- verify there).
    :return: a function ``_call(*args, **kwargs)``. ``args[0]`` must be an
             ``rf.ForwardHG``, or an ``rf.HyperOptimizer`` whose ``hyper_gradients``
             attribute is an ``rf.ForwardHG``; otherwise an ``AssertionError`` is
             raised (the check is skipped when Python runs with ``-O``). Calling it
             with no positional arguments raises ``IndexError``. It returns whatever
             the recorder built by ``Records.tensors`` returns.
    """
    if fd is None:
        fd = lambda stp, rs: rs  # default: pass the second argument through unchanged

    def _call(*args, **kwargs):
        hg = args[0]
        if isinstance(hg, rf.HyperOptimizer):
            hg = hg.hyper_gradients  # guess most common case
        assert isinstance(hg, rf.ForwardHG)
        recorder = Records.tensors(*hg.d_dynamics_d_hypers, op=tf.norm,
                                   fd=fd,
                                   # presumably suppresses recording when the second
                                   # argument is the string 'INIT' -- depends on how
                                   # Records.tensors invokes `condition`.
                                   condition=lambda stp, rs: rs != 'INIT')
        # NOTE(review): the positional tuple and keyword dict are passed as two
        # positional arguments (not unpacked with * / **). Confirm this matches the
        # calling convention of the function returned by Records.tensors.
        return recorder(args, kwargs)

    return _call
评论列表
文章目录