tensor_utils.py source code


Project: tensortools  Author: ahwillia
import numpy as np


def fold(unfolded_tensor, mode, shape):
    """Refold a mode-`mode` unfolding back into a tensor of shape `shape`.

    This is the inverse of the n-mode unfolding: it restores the
    original tensor of the specified shape.

    Parameters
    ----------
    unfolded_tensor : ndarray
        Unfolded tensor of shape ``(shape[mode], -1)``.
    mode : int
        The mode (axis) along which the tensor was unfolded.
    shape : tuple
        Shape of the original tensor before unfolding.

    Returns
    -------
    ndarray
        Folded tensor of shape `shape`.
    """
    # Shape of the unfolding before its trailing axes were flattened:
    # the mode dimension first, then the remaining dimensions in order.
    full_shape = list(shape)
    mode_dim = full_shape.pop(mode)
    full_shape.insert(0, mode_dim)
    # Restore those dimensions, then move the mode axis back into place.
    return np.moveaxis(unfolded_tensor.reshape(full_shape), 0, mode)
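`fold` only makes sense next to the unfolding it inverts. The sketch below assumes the row-major (C-order) convention that `fold` undoes: move axis `mode` to the front, then flatten the remaining axes. The `unfold` helper here is written for illustration and is an assumption; the project's own unfolding function may differ. The loop checks that `fold(unfold(X, mode), mode, X.shape)` recovers `X` for every mode of a small 2 x 3 x 4 tensor.

```python
import numpy as np


def unfold(tensor, mode):
    """Hypothetical mode-`mode` unfolding: move axis `mode` to the front,
    then flatten the remaining axes in C (row-major) order."""
    return np.moveaxis(tensor, mode, 0).reshape(tensor.shape[mode], -1)


def fold(unfolded_tensor, mode, shape):
    """Inverse of `unfold` above (same body as the original `fold`)."""
    full_shape = list(shape)
    mode_dim = full_shape.pop(mode)
    full_shape.insert(0, mode_dim)
    return np.moveaxis(unfolded_tensor.reshape(full_shape), 0, mode)


# Round-trip check on a small 2 x 3 x 4 tensor.
X = np.arange(24).reshape(2, 3, 4)
for mode in range(X.ndim):
    M = unfold(X, mode)
    # Each unfolding has shape (X.shape[mode], product of the other dims).
    assert M.shape == (X.shape[mode], X.size // X.shape[mode])
    assert np.array_equal(fold(M, mode, X.shape), X)

print(unfold(X, 1).shape)  # (3, 8)
```

For `mode=1`, `fold` rebuilds the intermediate shape `[3, 2, 4]` (mode dimension first), reshapes the `(3, 8)` matrix into it, and moves axis 0 back to position 1. Because `unfold` flattens in the same C order that `reshape` uses, the two functions are exact inverses. An unfolding that flattens in Fortran order would need a matching `order="F"` reshape in `fold`.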