def gauss_newton_vec(ys, zs, xs, vs):
    """Compute the Gauss-Newton vector product J' H J v.

    J is the Jacobian of ``zs`` with respect to the weights ``xs``. H is the
    Hessian of the loss ``ys`` with respect to ``zs``. The product is built
    in three steps: a reverse-mode gradient dL/dz, a forward-mode
    directional derivative of that gradient along ``vs`` (giving H J v), and
    a reverse-mode vector-Jacobian product through ``zs`` (giving J' H J v).

    Args:
        ys: Loss tensor.
        zs: Activations before the output layer (input to softmax).
        xs: Weights; a single tensor or a list/tuple of tensors.
        vs: Perturbation vectors, one per weight tensor in ``xs``.

    Returns:
        A tuple ``(jhjv, hjv)``:
        - ``jhjv`` is the Gauss-Newton vector product J' H J v, as returned
          by ``tf.gradients`` (one entry per weight tensor).
        - ``hjv`` is the intermediate H J v produced by ``forward_gradients``.

    Raises:
        ValueError: If ``xs`` is a list or tuple whose length differs from
            that of ``vs``.
    """
    # Every weight tensor needs exactly one matching perturbation vector.
    # isinstance() also covers tuples and list subclasses, which the previous
    # exact-type check let through unvalidated.
    if isinstance(xs, (list, tuple)) and len(vs) != len(xs):
        raise ValueError("xs and vs must have the same length.")
    # Reverse mode: gradient of the loss with respect to the pre-output
    # activations, dL/dz.
    grads_z = tf.gradients(ys, zs, gate_gradients=True)
    # Forward mode: directional derivative of dL/dz along vs, i.e. H J v.
    # This holds assuming ys depends on xs only through zs -- TODO confirm
    # for the models this is used with.
    hjv = forward_gradients(grads_z, xs, vs, gate_gradients=True)
    # Reverse mode again: tf.gradients with grad_ys=hjv computes the
    # vector-Jacobian product J' (H J v).
    jhjv = tf.gradients(zs, xs, hjv, gate_gradients=True)
    return jhjv, hjv
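# Scraped web-page residue, kept as a comment so the module still imports:
# "评论列表" (comment list) / "文章目录" (table of contents).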