basic.py source code

python

Project: Theano-Deep-learning · Author: GeekLiB
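# Context: this is the grad method of the MaxAndArgmax op in Theano's
# theano/tensor/basic.py. The names used below are assumed to come from
# the surrounding module, roughly (standard Theano layout):
#   from six.moves import xrange
#   from theano.gradient import DisconnectedType, grad_undefined
#   from theano.tensor.elemwise import DimShuffle
#   from theano.tensor.type_other import NoneConst
# eq and max are defined in theano/tensor/basic.py itself.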
def grad(self, inp, grads):
        # The strict mathematical gradient of the maximum function is
        # not computed here, because it is undefined at points where some
        # coordinates are identical. However, since that set has null
        # Lebesgue measure, the result may be interpreted as a weak
        # gradient.

        # @note: This function should work correctly for L{vector}s.
        # (x, y), (gz, gw)
        # gz*dz/dx + gw*dw/dx, gz*dz/dy + gw*dw/dy
        # gMax * dMax/dx + gArgMax * dArgMax/dx,
        # gMax * dMax/daxis + gArgMax * dArgMax/daxis
        # g_max has one fewer dimension than x, so g_max must be completed
        # to x's shape; when axis=0 the broadcasting mechanism does it
        # automatically.
        x, axis = inp
        g_max, g_max_idx = grads

        g_max_disconnected = isinstance(g_max.type, DisconnectedType)
        g_max_idx_disconnected = isinstance(g_max_idx.type, DisconnectedType)

        # if the op is totally disconnected, so are its inputs
        if g_max_disconnected and g_max_idx_disconnected:
            return [DisconnectedType()(), DisconnectedType()()]

        axis_grad = grad_undefined(
            self, 1, axis,
            "argmax is not defined for non-integer axes so"
            " argmax(x, axis+eps) is undefined")

        # if the max is disconnected but the argmax is not,
        # the gradient on its inputs is zero
        if g_max_disconnected:
            return [x.zeros_like(), axis_grad]
        if NoneConst.equals(axis):
            axis_ = list(range(x.ndim))
        else:
            axis_ = axis
        # max here is theano.tensor's max reduction, not Python's builtin.
        xmax = max(x, axis_)

        # Raise g_max and xmax to the same number of dimensions as the input.
        pattern = []
        out_dim = 0
        if NoneConst.equals(axis):
            # We are taking the max/argmax over all dimensions.
            axis = None
        for i in xrange(x.ndim):
            if axis is None or i in axis.data:
                pattern.append('x')
            else:
                pattern.append(out_dim)
                out_dim += 1
        g_max_pad = DimShuffle(g_max.broadcastable, pattern)(g_max)
        xmax_pad = DimShuffle(xmax.broadcastable, pattern)(xmax)

        # Route the gradient to the argmax positions; it is zero elsewhere.
        g_x = eq(xmax_pad, x) * g_max_pad
        return g_x, axis_grad
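
To see what the final masking step computes, here is a small NumPy sketch (illustrative only; the array values are made up). keepdims/expand_dims plays the role of the DimShuffle pattern above: g_max is broadcast back to x's shape and routed to the positions of the maxima.

import numpy as np

x = np.array([[1., 5., 3.],
              [4., 2., 6.]])
axis = 1
g_max = np.array([10., 20.])                # incoming gradient, one per row
xmax_pad = x.max(axis=axis, keepdims=True)  # plays the role of xmax_pad
g_max_pad = np.expand_dims(g_max, axis)     # plays the role of g_max_pad
g_x = (x == xmax_pad) * g_max_pad
print(g_x)
# [[ 0. 10.  0.]
#  [ 0.  0. 20.]]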
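
And a minimal end-to-end check, assuming a working Theano installation (Theano is no longer maintained): the gradient of max(x) with respect to x is 1 at the position of the maximum and 0 elsewhere.

import numpy as np
import theano
import theano.tensor as tt

x = tt.vector('x')
g = theano.grad(tt.max(x), x)  # exercises the grad method above
f = theano.function([x], g)
print(f(np.asarray([1., 3., 2.], dtype=theano.config.floatX)))
# expected: [ 0.  1.  0.]

Note that when several entries tie for the maximum, eq(xmax_pad, x) is 1 at every tied position, so each receives the full incoming gradient; this is the weak-gradient caveat from the comment at the top of the method.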