def tensor_function(self, tensor):
    """Convert *tensor* to an ndarray, moving a trailing channel axis to the front.

    Args:
        self: Not used by this function.
        tensor: Anything ``np.asarray`` accepts. After conversion it must be
            2-D or 3-D. For 3-D input the last axis is treated as the
            channel axis.

    Returns:
        np.ndarray: 2-D input is returned as converted. 3-D input is
        returned with its last axis moved to position 0, so an array of
        shape ``(a, b, c)`` becomes ``(c, a, b)``.

    Raises:
        NotImplementedError: If the converted array is neither 2-D nor 3-D.
    """
    arr = np.asarray(tensor)
    rank = arr.ndim

    # Guard clause: only 2-D and 3-D arrays are supported.
    if rank not in (2, 3):
        raise NotImplementedError("Expected tensor to be a 2D or 3D "
                                  "numpy array, got a {}D array instead."
                                  .format(rank))

    if rank == 3:
        # Channel axis is last on input; relocate it to the front.
        arr = np.moveaxis(arr, -1, 0)

    return arr
评论列表
文章目录