def _lower_bound_grad(op, grad):
    """Compute the gradient for `_lower_bound`.

    The incoming gradient is passed through element-wise wherever the
    first input is at or above the bound, or wherever the incoming
    gradient is negative. Everywhere else it is replaced by zero.

    Args:
      op: the tensorflow op for which to calculate a gradient; its first
        input is the bounded tensor and its second input is the bound.
      grad: gradient with respect to the output of the op.

    Returns:
      A two-element list: the masked gradient with respect to the first
      input, and `None` for the bound (no gradient is propagated to it).
    """
    values = op.inputs[0]
    lower_limit = op.inputs[1]

    # Condition 1: the input already satisfies the bound.
    satisfies_bound = values >= lower_limit
    # Condition 2: the incoming gradient is negative. Presumably this is the
    # case where a gradient-descent step would increase the input, i.e. move
    # it towards the bound -- confirm against the optimizer convention in use.
    negative_grad = grad < 0

    keep = tf.logical_or(satisfies_bound, negative_grad)
    # Cast the boolean mask to the gradient dtype so it can zero out entries.
    mask = tf.cast(keep, grad.dtype)
    return [mask * grad, None]
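# Web-page residue, not code: "评论列表" (comment list)
# Web-page residue, not code: "文章目录" (article table of contents)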