basic_ops.py source file

python

Project: Theano-Deep-learning  Author: GeekLiB
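# Method of an Op class in basic_ops.py; the enclosing class and the
# module's imports are not shown on this page.
def make_node(self, x, ilist):
    # Convert x to a CudaNdarray variable and ilist to a symbolic tensor.
    x_ = as_cuda_ndarray_variable(x)
    ilist_ = tensor.as_tensor_variable(ilist)
    # Accept only signed or unsigned integer dtypes ('int*' or 'uint*').
    if ilist_.type.dtype[:3] not in ('int', 'uin'):
        raise TypeError('index must be integers')
    if ilist_.type.ndim != 1:
        raise TypeError('index must be vector')
    if x_.type.ndim == 0:
        raise TypeError('cannot index into a scalar')

    # The C code assumes int64 indices, so upcast narrower integer types.
    # Use the converted variable x_ so that inputs without an ndim or
    # dtype attribute (e.g. Python lists) are handled as well.
    if x_.type.ndim in [1, 2, 3] and ilist_.dtype in [
            'int8', 'int16', 'int32', 'uint8', 'uint16', 'uint32']:
        ilist_ = tensor.cast(ilist_, 'int64')

    # The first output dimension follows the index vector; the remaining
    # dimensions follow x.
    bcast = (ilist_.broadcastable[0],) + x_.broadcastable[1:]
    return Apply(self, [x_, ilist_],
                 [CudaNdarrayType(dtype=x_.dtype,
                                  broadcastable=bcast)()])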
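
The checks in make_node can be tried without Theano or a GPU. The following is a minimal NumPy sketch of the same logic, not the Theano implementation. The helper names take_rows and output_broadcastable are hypothetical and exist only for illustration. Unlike the original, the sketch upcasts every integer index type narrower than 64 bits, regardless of the number of dimensions of x.

```python
import numpy as np


def take_rows(x, ilist):
    """Hypothetical NumPy analogue of the validation done in make_node."""
    x = np.asarray(x)
    ilist = np.asarray(ilist)
    # Same three checks as make_node: integer dtype, 1-D index, non-scalar x.
    if ilist.dtype.kind not in ("i", "u"):
        raise TypeError("index must be integers")
    if ilist.ndim != 1:
        raise TypeError("index must be vector")
    if x.ndim == 0:
        raise TypeError("cannot index into a scalar")
    # Simplification: upcast any index type narrower than 64 bits to int64.
    if ilist.dtype.itemsize < 8:
        ilist = ilist.astype(np.int64)
    # Integer-array indexing selects whole slices along the first axis.
    return x[ilist], ilist


def output_broadcastable(x_bcast, ilist_bcast):
    """Mirror of the bcast expression: first flag from the index vector,
    remaining flags from x."""
    return (ilist_bcast[0],) + tuple(x_bcast[1:])


x = np.arange(12, dtype=np.float32).reshape(4, 3)
idx = np.array([2, 0], dtype=np.int8)
out, idx64 = take_rows(x, idx)
print(out.shape, idx64.dtype)
print(output_broadcastable((False, True), (False,)))
```

The result has shape (len(ilist),) + x.shape[1:], which is why the broadcastable pattern takes its first entry from the index vector and the rest from x.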