Tensor.py 文件源码

python
阅读 41 收藏 0 点赞 0 评论 0

项目:HamiltonianPy 作者: waltergu 项目源码 文件源码
def directsum(tensors,labels,axes=()):
    '''
    The directsum of a couple of tensors.

    Parameters
    ----------
    tensors : list of Tensor
        The tensors to be directsummed. At least one tensor is required, and all
        of them must have the same number of dimensions, the same `qnon` flag and
        the same extents along `axes`.
    labels : list of Label
        The labels of the directsum. When the tensors carry quantum numbers
        (`qnon` is true), the `qns` attribute of these labels is overwritten in place.
    axes : list of integer, optional
        The axes along which the directsum is block diagonal. Along these axes
        the result keeps the common extent of the tensors; along every other
        axis the extents add up and each tensor fills its own consecutive block.

    Returns
    -------
    Tensor
        The directsum of the tensors.

    Raises
    ------
    AssertionError
        When `tensors` is empty, when the first tensor does not have more
        dimensions than `len(axes)`, or when the tensors are not compatible.
    '''
    tensors=list(tensors)
    assert len(tensors)>0,'directsum error: at least one tensor is required.'
    first=tensors[0]
    assert first.ndim>len(axes)
    ndim,qnon=first.ndim,first.qnon
    shps=[first.shape[axis] for axis in axes]
    # The axes not listed in `axes` are the ones along which the blocks are stacked.
    alters=set(range(ndim))-set(axes)
    shape=list(first.shape)
    for tensor in tensors[1:]:
        assert tensor.ndim==ndim and tensor.qnon==qnon and [tensor.shape[axis] for axis in axes]==shps
        for alter in alters:
            shape[alter]+=tensor.shape[alter]
    # np.result_type replaces np.find_common_type, which was removed in NumPy 2.0.
    data=np.zeros(tuple(shape),dtype=np.result_type(*[tensor.dtype for tensor in tensors]))
    # offsets[axis] is the start of the next free block along a stacked axis.
    offsets=[0]*ndim
    for tensor in tensors:
        slices=[slice(offsets[axis],offsets[axis]+tensor.shape[axis]) if axis in alters else slice(None,None,None) for axis in range(ndim)]
        data[tuple(slices)]=tensor[...]
        for alter in alters:
            offsets[alter]+=tensor.shape[alter]
    if qnon:
        for alter in alters:
            labels[alter].qns=QuantumNumbers.union([tensor.labels[alter].qns for tensor in tensors])
        # The shared axes take their quantum numbers from the first tensor
        # (the original iterated over a single tensor here instead of the list).
        for axis in axes:
            labels[axis].qns=first.labels[axis].qns
    return Tensor(data,labels=labels)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号