nnet.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def local_useless_crossentropy_softmax_1hot_with_bias_dx_alloc(node):
    """Skip a redundant `alloc` that feeds a CrossentropySoftmax1HotWithBiasDx.

    When the incoming gradient ``dy`` of a ``CrossentropySoftmax1HotWithBiasDx``
    node is produced by ``tensor.Alloc``, the node is rebuilt so that it
    consumes the value being allocated (``dz`` in ``T.alloc(dz, <shape>)``)
    directly. The op's own broadcasting then replaces the materialised
    ``alloc``. This covers an allocated scalar, or a value whose dimensions
    are broadcastable or match those of the output.

    Parameters
    ----------
    node
        The Apply node being visited by the optimizer.

    Returns
    -------
    list, False or None
        ``[new_output]`` when the rewrite applies.
        ``False`` when ``dy`` is 0-d or a 1-d vector with a broadcastable
        axis, or when ``node.fgraph`` has no ``shape_feature``.
        ``None`` when ``node.op`` is not a
        ``CrossentropySoftmax1HotWithBiasDx`` or ``dy`` is not the output of
        an ``Alloc``.
    """
    if not isinstance(node.op, CrossentropySoftmax1HotWithBiasDx):
        return None

    dy, sm, y_idx = node.inputs

    # A 0-d `dy`, or a 1-d `dy` whose single axis is broadcastable, is already
    # covered by the internal broadcasting of the op itself.
    if dy.ndim == 0 or (dy.ndim == 1 and dy.broadcastable[0]):
        return False

    assert dy.ndim == 1

    producer = dy.owner
    if producer is None or not isinstance(producer.op, tensor.Alloc):
        return None

    # `dy` is T.alloc(alloc_value, <shape>); grab the allocated value.
    alloc_value = producer.inputs[0]

    # The static shape reasoning below needs the fgraph's shape feature. Some
    # modes do not install it, and then this rewrite cannot proceed.
    try:
        feature = node.fgraph.shape_feature
    except AttributeError:
        return False

    shape_of, same_shape = feature.shape_of, feature.same_shape

    # Left-pad with broadcastable flags so there is one flag per axis of `dy`,
    # even when `alloc_value` has fewer dimensions (e.g. a scalar).
    padding = (True,) * (dy.ndim - alloc_value.ndim)
    value_broadcastable = padding + alloc_value.broadcastable

    # A runtime guard is only needed when `alloc_value` broadcasts along
    # axis 0 AND it is not statically known that `sm` and `dy` agree on
    # axis 0 AND it is not statically known that `dy` has length 1 there
    # (length 1 would trigger the op's internal broadcasting).
    needs_runtime_check = (
        value_broadcastable[0]
        and not same_shape(sm, dy, dim_x=0, dim_y=0)
        and shape_of[dy][0] != 1
    )
    if needs_runtime_check:
        # At runtime, `dy` must either have length 1 or match `sm` on axis 0.
        shapes_compatible = tensor.or_(
            tensor.eq(dy.shape[0], 1),
            tensor.eq(dy.shape[0], sm.shape[0]),
        )
        guard = opt.Assert('`sm` and `dy` do not have the same shape.')
        alloc_value = guard(alloc_value, shapes_compatible)

    replacement = node.op(alloc_value, sm, y_idx)
    copy_stack_trace(node.outputs[0], replacement)
    return [replacement]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号