torch_backend.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:ktorch 作者: farizrahman4u 项目源码 文件源码
def reshape(x, shape):
    def _reshape(x, shape=shape):
        return x.view(shape)

    def _compute_output_shape(x, shape=shape):
        if -1 not in shape:
            return shape
        else:
            n_elems = np.prod(list(_get_shape(x)))
            new_shape = list(shape)
            new_shape.remove(-1)
            new_axis = n_elems // np.prod(new_shape)
            s = list(shape)
            s[s.index(-1)] = new_axis
            return tuple(s)

    return get_op(_reshape, output_shape=_compute_output_shape, arguments=shape)(x)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号