def _local_dot(ltens_l, ltens_r, axes):
"""Computes the local tensors of a dot product `dot(l, r)`.
Besides computing the normal dot product, this function rearranges the
virtual legs in such a way that the result is a valid local tensor again.
:param ltens_l: Array with `ndim > 1`
:param ltens_r: Array with `ndim > 1`
:param axes: Axes to compute dot product using the convention of
:func:`numpy.tensordot()`. Note that these correspond to the true
(and not the just the physical) legs of the local tensors
:returns: Correct local tensor representation
"""
# number of contracted legs need to be the same
clegs_l = len(axes[0]) if isinstance(axes[0], collections.Sequence) else 1
clegs_r = len(axes[1]) if isinstance(axes[0], collections.Sequence) else 1
assert clegs_l == clegs_r, \
"Number of contracted legs differ: {} != {}".format(clegs_l, clegs_r)
res = np.tensordot(ltens_l, ltens_r, axes=axes)
# Rearrange the virtual-dimension legs
res = np.rollaxis(res, ltens_l.ndim - clegs_l, 1)
res = np.rollaxis(res, ltens_l.ndim - clegs_l,
ltens_l.ndim + ltens_r.ndim - clegs_l - clegs_r - 1)
return res.reshape((ltens_l.shape[0] * ltens_r.shape[0], ) +
res.shape[2:-2] +
(ltens_l.shape[-1] * ltens_r.shape[-1],))
评论列表
文章目录