import tensorflow as tf  # TF1-style graph API; tf.gradients requires graph mode


def fisher_vec_bk(ys, xs, vs):
    """Computes a Fisher vector product using backward-mode AD.

    For a scalar loss, the gradient g = dys/dxs plays the role of the
    Jacobian J, so J'Jv = g * (g . v), the outer-product (empirical
    Fisher) vector product.

    Args:
        ys: Loss, a scalar tensor.
        xs: Weights, a list of tensors.
        vs: List of tensors to multiply, one per weight tensor.

    Returns:
        J'Jv: Fisher vector product, a list of tensors shaped like xs.
    """
    # Validate the input: xs and vs must pair up one-to-one.
    if isinstance(xs, list) and len(vs) != len(xs):
        raise ValueError("xs and vs must have the same length.")
    # g = dys/dxs, one gradient tensor per weight tensor.
    grads = tf.gradients(ys, xs, gate_gradients=True)
    # Jv = g . v, summed across all weight tensors (a scalar).
    gradsv = [tf.reduce_sum(g * v) for g, v in zip(grads, vs)]
    jv = tf.add_n(gradsv)
    # J'Jv = g * (g . v), broadcast back to each weight's shape.
    return [g * jv for g in grads]
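As a quick sanity check, here is a minimal sketch of how the function might be called; the variable names and the tiny quadratic loss are illustrative, not from the original, and the compat call is only needed when running on TF2, where eager execution must be disabled for tf.gradients to work:

import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # tf.gradients needs graph mode on TF2

# Tiny quadratic loss over a single weight tensor.
w = tf.Variable([1.0, 2.0], dtype=tf.float32)
loss = tf.reduce_sum(w * w)       # ys, a scalar
v = tf.constant([0.5, -1.0])      # direction to multiply by

fvp = fisher_vec_bk(loss, [w], [v])

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    # g = 2w = [2, 4]; g . v = -3; so J'Jv should be [[-6, -12]].
    print(sess.run(fvp))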