subtensor.py source file

python
Project: Theano-Deep-learning  Author: GeekLiB
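# Imports this excerpt relies on (module paths as used in Theano's
# theano/tensor/subtensor.py).
import numpy

from theano import gof
from theano.tensor.type_other import NoneConst, SliceType


def adv_index_broadcastable_pattern(a, idx):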
    """
    This function is only used to determine the broadcast pattern for
    AdvancedSubtensor output variable.

    For this, we make a fake ndarray and a fake idx and ask numpy for the
    output shape. From this, we find the output broadcast pattern.

    """

    def replace_slice(v):
        if isinstance(v, gof.Apply):
            if len(v.outputs) != 1:
                raise ValueError(
                    "It is ambiguous which output of a multi-output Op has"
                    " to be fetched.", v)
            else:
                v = v.outputs[0]

        if NoneConst.equals(v):
            return None
        if isinstance(v.type, SliceType):
            return slice(None, None)

        return numpy.zeros((2,) * v.ndim, int)

    newidx = tuple(map(replace_slice, idx))

    # 2 - True = 1; 2 - False = 2
    fakeshape = [2 - bc for bc in a.broadcastable]
    retshape = numpy.empty(fakeshape)[newidx].shape
    return tuple([dim == 1 for dim in retshape])
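
The same trick can be tried without Theano. The sketch below is a NumPy-only analogue, not part of the original file: `fake_broadcast_pattern` is a hypothetical helper that takes a plain tuple of booleans in place of `a.broadcastable` and an index that is already in NumPy form, so the `replace_slice` step is skipped. Broadcastable dimensions get length 1, all others get length 2, and NumPy's own advanced-indexing rules produce the result shape.

```python
import numpy as np


def fake_broadcast_pattern(broadcastable, index):
    """Hypothetical NumPy-only analogue of adv_index_broadcastable_pattern."""
    # Broadcastable dimensions become length 1, the rest length 2
    # (the same values as `2 - bc` in the function above).
    fake_shape = [1 if bc else 2 for bc in broadcastable]
    # Let NumPy apply its indexing rules to a throwaway array.
    result_shape = np.empty(fake_shape)[index].shape
    # A result dimension of length 1 is reported as broadcastable.
    return tuple(dim == 1 for dim in result_shape)


# A (non-broadcastable, broadcastable) matrix indexed with an integer
# vector on axis 0 and a full slice on axis 1.
vector_index = (np.zeros((2,), int), slice(None, None))
pattern_a = fake_broadcast_pattern((False, True), vector_index)

# A (broadcastable, non-broadcastable) matrix indexed with an integer
# matrix on axis 0: the broadcastable axis is replaced by the two
# dimensions of the index.
matrix_index = (np.zeros((2, 2), int), slice(None, None))
pattern_b = fake_broadcast_pattern((True, False), matrix_index)
```

Here `pattern_a` is `(False, True)` and `pattern_b` is `(False, False, False)`. As in the original, every index array is faked with length-2 dimensions, so a broadcastable dimension of the index itself is not reflected in the result.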