math_grad.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:complex_tf 作者: woodshop 项目源码 文件源码
def _CplxMatMulGrad(op, grad):
  inp0 = tf.conj(op.inputs[0])
  inp1 = tf.conj(op.inputs[1])
  t_a = op.get_attr("transpose_a")
  t_b = op.get_attr("transpose_b")
  if not t_a and not t_b:
    return (math_ops.matmul(
        grad, inp1, transpose_b=True), math_ops.matmul(
            inp0, grad, transpose_a=True))
  elif not t_a and t_b:
    return (math_ops.matmul(grad, inp1), math_ops.matmul(
        grad, inp0, transpose_a=True))
  elif t_a and not t_b:
    return (math_ops.matmul(
        inp1, grad, transpose_b=True),
            math_ops.matmul(inp0, grad))
  elif t_a and t_b:
    return (math_ops.matmul(
        inp1, grad, transpose_a=True, transpose_b=True),
            math_ops.matmul(
                grad, inp0, transpose_a=True, transpose_b=True))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号