def reshape_workaround(data, shape_out):  # type: (TensorOp, Sequence[int]) -> TensorOp
    """Limited workaround for tensor reshape operation.

    Only reshapes that flatten ``data`` to one or two dimensions are supported.
    A 2-d target is accepted only when its first dimension equals the product
    of some leading run of the input dimensions, so that the reshape can be
    expressed as a single ``ng.flatten_at`` call.

    :param data: tensor to reshape; its dimension sizes are read from
        ``data.shape.lengths``.
    :param shape_out: requested output dimension sizes.
    :return: the flattened tensor with its axes cast to
        ``make_pos_axes(shape_out)``.
    :raises ValueError: if the total number of elements of the input and of
        ``shape_out`` differ, or if a 2-d ``shape_out`` does not split the
        input dimensions at a dimension boundary.
    :raises NotImplementedError: if ``shape_out`` has neither 1 nor 2
        dimensions.
    """
    shape_in = data.shape.lengths
    size_in = np.prod(shape_in)
    size_out = np.prod(shape_out)
    if size_in != size_out:
        # Format eagerly: passing the values as extra exception arguments
        # would leave the '%d' placeholders unfilled in the message.
        raise ValueError('Total size of input (%d) and output (%d) dimension mismatch.'
                         % (size_in, size_out))

    ndims_out = len(shape_out)
    if ndims_out == 1:
        tensor = ng.flatten(data)
    elif ndims_out == 2:
        # cumprods[i] is the size obtained by flattening input dims 0..i
        # together; the first output dim must equal one of these values.
        cumprods = list(np.cumprod(shape_in))
        try:
            flatten_at_idx = cumprods.index(shape_out[0]) + 1
        except ValueError:
            # Keep ValueError (what list.index raised before) for callers,
            # but explain why the reshape is rejected.
            raise ValueError('Reshape to %s is not supported: first output dimension (%d) '
                             'is not a product of leading input dimensions %s.'
                             % (tuple(shape_out), shape_out[0], tuple(shape_in)))
        tensor = ng.flatten_at(data, flatten_at_idx)
    else:
        raise NotImplementedError('Reshape can only support flatten to 1d or 2d.')
    return ng.cast_axes(tensor, make_pos_axes(shape_out))
评论列表
文章目录