def directsum(tensors,labels,axes=()):
    '''
    The directsum of a couple of tensors.

    Every axis not listed in ``axes`` is concatenated: the result's extent along it
    is the sum of the tensors' extents, and each tensor occupies its own block, so
    the tensors sit block-diagonally in those axes. Axes listed in ``axes`` are
    shared: every tensor must have the same extent along them and the result keeps
    that extent.

    Parameters
    ----------
    tensors : list of Tensor
        The tensors to be directsummed. Must be non-empty, and every tensor must
        have the same ``ndim`` and ``qnon`` as the first one and the same extent
        along each axis in ``axes``.
    labels : list of Label
        The labels of the directsum, indexed by axis. When the tensors carry
        quantum numbers (``qnon`` is true), the ``qns`` attribute of these labels
        is overwritten in place.
    axes : list of integer, optional
        The shared axes, i.e. the axes that are not concatenated.

    Returns
    -------
    Tensor
        The directsum of the tensors.

    Raises
    ------
    ValueError
        If ``tensors`` is empty.
    '''
    # Materialize once: the tensors are traversed several times below.
    tensors=list(tensors)
    if not tensors:
        raise ValueError('directsum requires at least one tensor.')
    first=tensors[0]
    assert first.ndim>len(axes)
    ndim,qnon=first.ndim,first.qnon
    shps=[first.shape[axis] for axis in axes]
    alters=set(range(ndim))-set(axes)
    shape=list(first.shape)
    for tensor in tensors[1:]:
        assert tensor.ndim==ndim and tensor.qnon==qnon and [tensor.shape[axis] for axis in axes]==shps
        for alter in alters: shape[alter]+=tensor.shape[alter]
    # np.result_type performs the same dtype promotion that the deprecated
    # (removed in NumPy 2.0) np.find_common_type([],dtypes) did.
    data=np.zeros(tuple(shape),dtype=np.result_type(*[tensor.dtype for tensor in tensors]))
    # starts[axis] is where the next tensor's block begins along a concatenated axis;
    # shared axes are always taken whole.
    starts=[0]*ndim
    for tensor in tensors:
        slices=tuple(slice(starts[axis],starts[axis]+tensor.shape[axis]) if axis in alters else slice(None,None,None) for axis in range(ndim))
        data[slices]=tensor[...]
        for alter in alters: starts[alter]+=tensor.shape[alter]
    if qnon:
        for alter in alters:
            labels[alter].qns=QuantumNumbers.union([tensor.labels[alter].qns for tensor in tensors])
        for axis in axes:
            # Shared axes take their quantum numbers from the first tensor.
            labels[axis].qns=first.labels[axis].qns
    return Tensor(data,labels=labels)
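# 评论列表 (Comment list) -- web-page navigation residue, not part of the code
# 文章目录 (Table of contents) -- web-page navigation residue, not part of the code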