axes.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def reshape_workaround(data, shape_out):  # type: (TensorOp, Sequence[int]) -> TensorOp
    """Limited workaround for tensor reshape operation.

    Only two kinds of reshape are supported:

    * flattening ``data`` to 1-d;
    * flattening ``data`` to 2-d by splitting the input axes into a leading
      group and a trailing group.

    :param data: input tensor; its axis lengths are read from
        ``data.shape.lengths``.
    :param shape_out: requested output shape (sequence of ints).
    :return: ``data`` flattened with ``ng.flatten`` / ``ng.flatten_at`` and
        then cast with ``ng.cast_axes`` to the axes returned by
        ``make_pos_axes(shape_out)``.
    :raises ValueError: if the total number of elements of ``data`` and
        ``shape_out`` differ, or if a 2-d ``shape_out`` has a first dimension
        that is not the product of some leading input dimensions.
    :raises NotImplementedError: if ``shape_out`` is neither 1-d nor 2-d.
    """
    shape_in = data.shape.lengths

    size_in = np.prod(shape_in)
    size_out = np.prod(shape_out)
    if size_in != size_out:
        # Interpolate the sizes into the message. Passing them as extra
        # exception arguments would leave the '%d' placeholders unformatted.
        raise ValueError('Total size of input (%d) and output (%d) dimension mismatch.'
                         % (size_in, size_out))

    ndims_out = len(shape_out)
    if ndims_out == 1:
        tensor = ng.flatten(data)
    elif ndims_out == 2:
        # The first output dimension must equal the product of a prefix of the
        # input dimensions. The tensor is then flattened right after that prefix.
        cumprods = list(np.cumprod(shape_in))
        leading = shape_out[0]
        if leading not in cumprods:
            raise ValueError('Cannot reshape %s to %s: leading output dimension %d must equal '
                             'a product of leading input dimensions.'
                             % (tuple(int(d) for d in shape_in),
                                tuple(int(d) for d in shape_out),
                                leading))
        flatten_at_idx = cumprods.index(leading) + 1
        tensor = ng.flatten_at(data, flatten_at_idx)
    else:
        raise NotImplementedError('Reshape can only support flatten to 1d or 2d.')

    return ng.cast_axes(tensor, make_pos_axes(shape_out))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号