def hessian_vec_bk(ys, xs, vs, grads=None):
    """Implements Hessian vector product using backward on backward AD.

    The product is obtained from two nested ``tf.gradients`` calls: first the
    gradient ``g = d ys / d xs``, then the gradient of ``sum_i <g_i, v_i>``
    with respect to ``xs`` (``vs`` are passed as the second call's
    ``grad_ys``).

    Args:
        ys: Loss function.
        xs: Weights; a single tensor or a list/tuple of tensors.
        vs: Tensors to multiply, one for each weight tensor in ``xs``.
        grads: Optional precomputed first-order gradients of ``ys`` with
            respect to ``xs``. Computed with ``tf.gradients`` when ``None``.

    Returns:
        Hv: List of Hessian vector products, one per weight in ``xs`` and of
        the same shape. Where ``tf.gradients`` reports no dependence
        (``None``), the entry is ``tf.zeros_like`` of that weight instead.

    Raises:
        ValueError: If ``xs`` and ``vs`` (or ``grads`` and ``vs``) have
            different lengths.
    """

    def _as_list(value):
        # Same normalisation tf.gradients applies: a bare tensor is one item.
        return list(value) if isinstance(value, (list, tuple)) else [value]

    xs_list = _as_list(xs)
    vs_list = _as_list(vs)
    # Validate the input; lists and tuples are both checked.
    if len(vs_list) != len(xs_list):
        raise ValueError("xs and vs must have the same length.")

    if grads is None:
        grads = tf.gradients(ys, xs_list, gate_gradients=True)
    grads_list = _as_list(grads)
    if len(grads_list) != len(vs_list):
        raise ValueError("grads and vs must have the same length.")

    # tf.gradients yields None for weights ys does not depend on. Such
    # gradients are identically zero, so they add nothing to H v; passing
    # None back into tf.gradients would fail, so drop those pairs.
    pairs = [(g, v) for g, v in zip(grads_list, vs_list) if g is not None]
    if pairs:
        kept_grads = [g for g, _ in pairs]
        kept_vs = [v for _, v in pairs]
        hvs = tf.gradients(kept_grads, xs_list, kept_vs, gate_gradients=True)
    else:
        hvs = [None] * len(xs_list)

    # A None here means the second derivative is structurally zero; return
    # explicit zeros so every entry has the shape of its weight.
    return [
        tf.zeros_like(x) if hv is None else hv
        for x, hv in zip(xs_list, hvs)
    ]
评论列表
文章目录