broadcast.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:chainer-deconv 作者: germanRos 项目源码 文件源码
def _backward_one(x, g):
    if g is None:
        xp = cuda.get_array_module(x)
        return xp.zeros_like(x)

    if g.ndim != x.ndim:
        g = g.sum(axis=tuple(range(g.ndim - x.ndim)))
        # An input variable is always an array, not a scalar.
        # We need to convert a scalar value to a zero-dim array.
        xp = cuda.get_array_module(x)
        if xp.isscalar(g):
            g = xp.array(g)

    axis = tuple(i for i, sx in enumerate(x.shape) if sx == 1)
    if len(axis) > 0:
        return g.sum(keepdims=True, axis=axis)
    else:
        return g
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号