fwgrad.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:tensorflow-forward-ad 作者: renmengye 项目源码 文件源码
def Pack_FwGrad(*args, **kwargs):
  dx = args[1:]
  axis = kwargs["axis"]
  if all(map(lambda x: x is None, dx)):
    log.error("hey")
    return None
  else:
    ### Here we need to fill in zeros.
    def _mapper(_):
      dx = _[0]
      x = _[1]
      return dx if dx is not None else tf.zeros_like(x)

    dx = list(map(_mapper, zip(dx, list(args[0].inputs))))
    if tf.__version__.startswith("0"):
      return tf.pack(dx, axis=axis)
    else:
      return tf.stack(dx, axis=axis)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号