type.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def filter_variable(self, other, allow_convert=True):
        """Coerce ``other`` into a variable whose type is ``self``.

        Steps, in order:

        1. If ``other`` exposes ``_as_GpuArrayVariable``, that hook is
           called with ``self.context_name`` and its result is used.
        2. If the result is not a ``Variable``, it is wrapped with
           ``self.Constant(type=self, data=other)``.
        3. If the variable's type already equals ``self`` it is returned
           unchanged.
        4. Otherwise the variable must be of a ``tensor.TensorType`` with
           the same dtype and number of dimensions as ``self``.  A
           differing broadcastable pattern is converted when
           ``allow_convert`` is true.
        5. The validated variable is passed through
           ``gpu_from_host(self.context_name)``.

        Parameters
        ----------
        other
            A ``Variable``, an object providing ``_as_GpuArrayVariable``,
            or raw data accepted by ``self.Constant``.
        allow_convert : bool
            When true, a mismatch in the broadcastable pattern is resolved
            with ``convert_variable`` on a clone of ``other``'s type that
            carries ``self.broadcastable``; when false such a mismatch is
            always an error.

        Returns
        -------
        Variable
            ``other`` itself when its type already equals ``self``,
            otherwise the output of ``gpu_from_host(self.context_name)``
            applied to the (possibly converted) variable.

        Raises
        ------
        TypeError
            If ``other``'s type is not a ``tensor.TensorType``, or its
            dtype, number of dimensions, or (unconvertible) broadcastable
            pattern differs from ``self``.
        """
        # Imported at call time rather than module level.
        # NOTE(review): presumably to avoid a circular import between the
        # type module and basic_ops -- confirm before hoisting.
        from theano.gpuarray.basic_ops import gpu_from_host

        if hasattr(other, '_as_GpuArrayVariable'):
            other = other._as_GpuArrayVariable(self.context_name)

        if not isinstance(other, Variable):
            other = self.Constant(type=self, data=other)

        if other.type == self:
            return other

        if not isinstance(other.type, tensor.TensorType):
            raise TypeError('Incompatible type', (self, other.type))
        if other.type.dtype != self.dtype:
            raise TypeError('Incompatible dtype', (self.dtype,
                                                   other.type.dtype))
        if other.type.ndim != self.ndim:
            # Report the same attribute that was just compared.
            raise TypeError('Incompatible number of dimensions.'
                            ' Expected %d, got %d.' %
                            (self.ndim, other.type.ndim))
        if other.type.broadcastable != self.broadcastable:
            converted = None
            if allow_convert:
                target_type = other.type.clone(
                    broadcastable=self.broadcastable)
                converted = target_type.convert_variable(other)
            if converted is None:
                # "Expected" is the pattern of this type; "got" is the
                # pattern of the variable being filtered.
                raise TypeError('Incompatible broadcastable dimensions.'
                                ' Expected %s, got %s.' %
                                (str(self.broadcastable),
                                 str(other.type.broadcastable)))
            other = converted

        return gpu_from_host(self.context_name)(other)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号