def test_disconnected_cost_grad():
    """Check that a cost declared disconnected via ``known_grads`` stays disconnected.

    Ops built around minigraphs, such as OpFromGraph and scan, implement
    ``Op.grad`` by passing their output gradients to ``known_grads``. For
    that to work, supplying a ``DisconnectedType`` value as the cost's known
    gradient must be honoured by the rest of the gradient machinery. Here
    that means ``disconnected_inputs='raise'`` must raise
    ``DisconnectedInputError`` rather than silently producing gradients.
    """
    lhs = theano.tensor.iscalar()
    rhs = theano.tensor.iscalar()
    total = lhs + rhs
    # Sanity check: the sum of two iscalars should carry a discrete dtype.
    assert total.dtype in theano.tensor.discrete_dtypes

    try:
        disconnected_grad = gradient.DisconnectedType()()
        theano.tensor.grad(
            total,
            [lhs, rhs],
            known_grads={total: disconnected_grad},
            disconnected_inputs='raise',
        )
    except theano.gradient.DisconnectedInputError:
        # Expected outcome: the disconnected cost gradient was respected.
        pass
    else:
        raise AssertionError("A disconnected gradient has been ignored.")
评论列表
文章目录