second_order.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:tensorflow-forward-ad 作者: renmengye 项目源码 文件源码
def fisher_vec_bk(ys, xs, vs):
  """Implements Fisher vector product using backward AD.

  Args:
    ys: Loss function, scalar.
    xs: Weights, list of tensors.
    vs: List of tensors to multiply, for each weight tensor.

  Returns:
    J'Jv: Fisher vector product.
  """
  # Validate the input
  if type(xs) == list:
    if len(vs) != len(xs):
      raise ValueError("xs and vs must have the same length.")

  grads = tf.gradients(ys, xs, gate_gradients=True)
  gradsv = list(map(lambda x: tf.reduce_sum(x[0] * x[1]), zip(grads, vs)))
  jv = tf.add_n(gradsv)
  jjv = list(map(lambda x: x * jv, grads))
  return jjv
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号