basic.py source code

python

Project: Theano-Deep-learning    Author: GeekLiB
def grad(self, axis_and_tensors, grads):
        """ The gradient wrt a join op is a `Split`, used to partition
        the gradient along the `axis` which was used for joining.
        """
        gz, = grads
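        # The node's inputs are laid out as [axis, tensor_1, ..., tensor_n].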
        axis, tensors = axis_and_tensors[0], axis_and_tensors[1:]

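        # Input 0 is the integer axis itself; no gradient is defined wrt it.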
        rval = [grad_undefined(self, 0, axis)]

        dtypes = [as_tensor_variable(x).type.dtype for x in tensors]
        out_dtype = scal.upcast(*dtypes)

        if 'float' in out_dtype or 'complex' in out_dtype:
            # assume that this is differentiable
            split = Split(len(tensors))
            split_gz = split(gz, axis, stack([shape(x)[axis]
                                              for x in tensors]))
            # If there is only one split, it might not be in a list.
            if not isinstance(split_gz, list):
                split_gz = [split_gz]
            # Split.make_node isn't always able to infer the right
            # broadcast pattern. Since the grad needs to keep this
            # information, read it from the inputs when needed.
            split_gz = [patternbroadcast(g, t.broadcastable)
                        for t, g in zip(tensors, split_gz)]
            rval = rval + split_gz
        else:
            # the output has integer type, so the gradient through it
            # is 0
            rval = rval + [tensor.zeros_like(dtype=config.floatX)
                           for tensor in tensors]

        return rval
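
For context, here is a minimal usage sketch (an assumed example with hypothetical variable names, not part of the original file) showing the behavior this gradient implements: joining two vectors and differentiating through the join, so the incoming gradient is partitioned back into one piece per input.

import theano
import theano.tensor as T

x = T.vector('x')
y = T.vector('y')
z = T.join(0, x, y)            # concatenate along axis 0
cost = (z ** 2).sum()
# grad() routes the upstream gradient through Split, slicing it
# along axis 0 into pieces matching the shapes of x and y.
gx, gy = T.grad(cost, [x, y])
f = theano.function([x, y], [gx, gy])

Since x and y have float dtype, this exercises the 'float' branch above; joining integer tensors instead would take the else branch and yield zero gradients.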