elemwise.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def _bgrad(self, inputs, ograds):
        """Build the gradient graph with respect to broadcasted inputs.

        The gradient is first obtained at the scalar level by calling
        ``self.scalar_op.grad`` on freshly created scalar variables (one per
        entry of `inputs` and `ograds`, with matching dtype).  The resulting
        scalar graph is then translated back into a graph of ``Elemwise``
        ops in which the scalar stand-ins are replaced by the original
        `inputs` and `ograds`.

        Parameters
        ----------
        inputs
            Sequence of variables the gradient is taken with respect to.
            ``inputs[0]`` determines the number of dimensions used when
            lifting scalar constants; all inputs are assumed to share it.
        ograds
            Sequence of gradients flowing in from the outputs.  Entries whose
            type is ``NullType`` or ``DisconnectedType`` are passed through
            to ``scalar_op.grad`` unchanged.

        Returns
        -------
        list
            One entry per ``(scalar gradient, input)`` pair: the translated
            gradient expression, or ``None`` where the scalar gradient is
            ``None``.

        Raises
        ------
        TypeError
            If ``self.scalar_op.grad`` returns something other than a list
            or a tuple.
        """
        # returns grad, with respect to broadcasted versions of inputs

        prev_setting = theano.config.compute_test_value

        try:

            # NOTE(review): test values are switched off while the scalar
            # graph is built, presumably because the stand-in scalar
            # variables created below carry no test value -- confirm.
            theano.config.compute_test_value = 'off'

            def as_scalar(t):
                # Null / disconnected gradients have no dtype to mirror, so
                # they are handed to the scalar op as they are.
                if isinstance(t.type, (NullType, DisconnectedType)):
                    return t
                return get_scalar_type(t.type.dtype)()

            scalar_inputs = list(map(as_scalar, inputs))
            scalar_ograds = list(map(as_scalar, ograds))
            scalar_igrads = self.scalar_op.grad(scalar_inputs, scalar_ograds)

            # Validate the container type *before* iterating over it, so a
            # misbehaving scalar op is reported with the message below rather
            # than with a generic "object is not iterable" error.
            if not isinstance(scalar_igrads, (list, tuple)):
                raise TypeError('%s.grad returned %s instead of list or tuple' %
                                (str(self.scalar_op), str(type(scalar_igrads))))

            for igrad in scalar_igrads:
                assert igrad is not None, self.scalar_op

        finally:

            # Always restore the caller's setting, even if grad() raised.
            theano.config.compute_test_value = prev_setting

        nd = len(inputs[0].type.broadcastable)  # this is the same for everyone

        def transform(r):
            # From a graph of ScalarOps, make a graph of Broadcast ops.
            if isinstance(r.type, (NullType, DisconnectedType)):
                return r
            # Scalar stand-ins map back to the variables they were made from.
            if r in scalar_inputs:
                return inputs[scalar_inputs.index(r)]
            if r in scalar_ograds:
                return ograds[scalar_ograds.index(r)]
            node = r.owner
            if node is None:
                # the gradient contains a constant, translate it as
                # an equivalent TensorType of size 1 and proper number of
                # dimensions
                res = theano.tensor.constant(numpy.asarray(r.data),
                                             dtype=r.type.dtype)
                return DimShuffle((), ['x'] * nd)(res)

            # Any other variable is the output of a scalar op: lift that op
            # to an Elemwise and recurse into its inputs.
            new_r = Elemwise(node.op, {})(
                *[transform(ipt) for ipt in node.inputs])
            return new_r

        ret = []
        # izip stops at the shorter sequence; the input itself is not needed,
        # only its position.
        for scalar_igrad, _ipt in izip(scalar_igrads, inputs):
            if scalar_igrad is None:
                # undefined gradient (only reachable when asserts are
                # disabled, since the loop above rejects None)
                ret.append(None)
                continue
            ret.append(transform(scalar_igrad))

        return ret
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号