def fisher_vec_z(ys, xs, vs):
    """Compute the product J J^T v, where v lives in the output space.

    J is the Jacobian of ``ys`` with respect to ``xs``. The product is formed
    in two passes: a reverse-mode pass (``tf.gradients``) maps v from the
    output space into the weight space (J^T v), then ``forward_gradients``
    maps that result back into the output space (J (J^T v)).

    Args:
        ys: Loss function or output variables; a single tensor or a
            list/tuple of tensors.
        xs: Weights, list of tensors.
        vs: Tensor(s) to multiply, living in the output space. When ``ys``
            is a list/tuple, ``vs`` must provide one entry per element of
            ``ys``.

    Returns:
        JJ'v: Fisher vector product on the output space.

    Raises:
        ValueError: If ``ys`` is a list/tuple and ``vs`` has a different
            length.
    """
    # Validate the input: v is an output-space vector, so it must pair
    # one-to-one with the outputs in ys.
    if isinstance(ys, (list, tuple)) and len(vs) != len(ys):
        raise ValueError("ys and vs must have the same length.")
    # Reverse mode: J^T v, one entry per weight tensor in xs.
    # NOTE(review): tf.gradients yields None for any x in xs that ys does not
    # depend on; confirm that forward_gradients tolerates None entries.
    jv = tf.gradients(ys, xs, vs, gate_gradients=True)
    # Forward mode (helper defined elsewhere in the project): J (J^T v),
    # which brings the result back into the output space.
    return forward_gradients(ys, xs, jv, gate_gradients=True)
评论列表
文章目录