def _bgrad(self, inputs, ograds):
    # Returns the gradient with respect to *broadcasted* versions of
    # the inputs; the caller is responsible for summing out the
    # broadcasted dimensions afterwards.
    prev_setting = theano.config.compute_test_value
    try:
        # Disable test values: the scalar graph built below is symbolic
        # scaffolding only and is never evaluated directly.
        theano.config.compute_test_value = 'off'

        def as_scalar(t):
            # Null/disconnected gradients pass through unchanged;
            # everything else becomes a fresh scalar variable.
            if isinstance(t.type, (NullType, DisconnectedType)):
                return t
            return get_scalar_type(t.type.dtype)()

        scalar_inputs = list(map(as_scalar, inputs))
        scalar_ograds = list(map(as_scalar, ograds))
        scalar_igrads = self.scalar_op.grad(scalar_inputs, scalar_ograds)
        for igrad in scalar_igrads:
            assert igrad is not None, self.scalar_op
    finally:
        theano.config.compute_test_value = prev_setting
    if not isinstance(scalar_igrads, (list, tuple)):
        raise TypeError('%s.grad returned %s instead of list or tuple' %
                        (str(self.scalar_op), str(type(scalar_igrads))))

    nd = len(inputs[0].type.broadcastable)  # the same for every input
    def transform(r):
        # From a graph of ScalarOps, make an equivalent graph of
        # Elemwise (broadcast) ops over the original tensor inputs.
        if isinstance(r.type, (NullType, DisconnectedType)):
            return r
        if r in scalar_inputs:
            return inputs[scalar_inputs.index(r)]
        if r in scalar_ograds:
            return ograds[scalar_ograds.index(r)]
        node = r.owner
        if node is None:
            # The gradient contains a constant; translate it to an
            # equivalent TensorType constant of size 1 with the proper
            # number of broadcastable dimensions.
            res = theano.tensor.constant(numpy.asarray(r.data),
                                         dtype=r.type.dtype)
            return DimShuffle((), ['x'] * nd)(res)
        # Recurse on the node's inputs, then rebuild this scalar op
        # as the corresponding Elemwise op.
        new_r = Elemwise(node.op, {})(
            *[transform(ipt) for ipt in node.inputs])
        return new_r
    ret = []
    # izip: iterator zip (itertools.izip on Python 2, plain zip on 3)
    for scalar_igrad, ipt in izip(scalar_igrads, inputs):
        if scalar_igrad is None:
            # Undefined gradient for this input.
            ret.append(None)
            continue
        ret.append(transform(scalar_igrad))
    return ret
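
Taken together, _bgrad builds the scalar gradient graph once, then lifts it node by node back into Elemwise ops over the broadcasted tensor inputs. Since transform recurses structurally over the scalar graph, any composite scalar gradient is mirrored one-for-one as Elemwise nodes, so the gradient graph broadcasts exactly like the forward graph. A minimal sketch of the effect, assuming a working Theano install (_bgrad itself is internal to theano.tensor.elemwise.Elemwise and is invoked for you by theano.grad):

    import theano
    import theano.tensor as tt

    x = tt.matrix('x')
    y = tt.tanh(x)                 # forward graph: Elemwise{tanh}
    g = theano.grad(y.sum(), x)    # gradient passes through Elemwise._bgrad
    theano.printing.debugprint(g)  # the grad graph is made of Elemwise ops too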