def fold(unfolded_tensor, mode, shape):
    """Refold a mode-`mode` unfolding into a tensor of shape `shape`.

    This inverts the mode-n unfolding that moves axis `mode` to the front
    and flattens the remaining axes in C (row-major) order, i.e.
    ``np.moveaxis(tensor, mode, 0).reshape(tensor.shape[mode], -1)``.

    Parameters
    ----------
    unfolded_tensor : array_like
        Unfolded tensor of shape ``(shape[mode], -1)``. Any array-like
        accepted by ``np.reshape`` works (e.g. nested lists).
    mode : int
        The mode of the unfolding. Negative values count from the end,
        following Python/NumPy indexing rules.
    shape : tuple of int
        Shape of the original tensor before unfolding.

    Returns
    -------
    ndarray
        Folded tensor of shape `shape`.

    Raises
    ------
    IndexError
        If `mode` is out of range for `shape` (raised by ``list.pop``).
    ValueError
        If the number of elements in `unfolded_tensor` does not match the
        product of `shape` (raised by ``np.reshape``).
    """
    # The unfolding placed axis `mode` first and kept the other axes in their
    # original relative order, so rebuild that intermediate shape first...
    full_shape = list(shape)
    mode_dim = full_shape.pop(mode)
    full_shape.insert(0, mode_dim)
    # ...then un-flatten (C order, matching the unfolding) and move axis 0
    # back to position `mode`. np.reshape also accepts non-ndarray inputs.
    return np.moveaxis(np.reshape(unfolded_tensor, full_shape), 0, mode)
评论列表
文章目录