torch_backend.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:ktorch 作者: farizrahman4u 项目源码 文件源码
def dot(x, y):
    """Matrix/vector product ``x @ y`` for rank-1 and rank-2 tensors.

    Supported rank combinations (rank of ``x``, rank of ``y``):

    * (2, 2): matrix @ matrix  -> shape ``(x.shape[0], y.shape[1])``
    * (2, 1): matrix @ vector  -> shape ``(x.shape[0],)``
    * (1, 2): vector @ matrix  -> shape ``(y.shape[1],)``
    * (1, 1): vector . vector  -> scalar, shape ``()``

    # Arguments
        x: tensor of rank 1 or 2.
        y: tensor of rank 1 or 2.

    # Returns
        Whatever ``get_op(...)`` produces when applied to ``[x, y]``.

    # Raises
        ValueError: from both the computation and the shape function when
            either operand has a rank other than 1 or 2 (when these run is
            decided by ``get_op``).
    """
    def _unsupported(a_ndim, b_ndim):
        # Shared by both inner functions so they report identical errors.
        return ValueError('Unsupported tensor ranks for dot operation : ' +
                          str(a_ndim) + ' and ' + str(b_ndim) + '.')

    def _dot(X):
        a, b = X
        a_ndim = ndim(a)
        b_ndim = ndim(b)
        if a_ndim == 2 and b_ndim == 2:
            return torch.mm(a, b)
        if a_ndim == 2 and b_ndim == 1:
            return torch.mv(a, b)
        if a_ndim == 1 and b_ndim == 2:
            # vector @ matrix == matrix^T @ vector. torch.mv(b, a) would
            # compute b @ a instead, which is a different product and only
            # type-checks when b is square.
            return torch.mv(b.t(), a)
        if a_ndim == 1 and b_ndim == 1:
            return torch.dot(a, b)
        raise _unsupported(a_ndim, b_ndim)

    def _compute_output_shape(X):
        a_shape, b_shape = _get_shape(X[0]), _get_shape(X[1])
        ranks = (len(a_shape), len(b_shape))
        if ranks == (2, 2):
            return (a_shape[0], b_shape[1])
        if ranks == (2, 1):
            return (a_shape[0],)
        if ranks == (1, 2):
            # (k,) @ (k, n) -> (n,)
            return (b_shape[1],)
        if ranks == (1, 1):
            # Inner product of two vectors is a scalar; the original (0,)
            # described an empty tensor instead.
            return ()
        # Previously fell through and returned None; fail loudly instead,
        # consistent with _dot.
        raise _unsupported(*ranks)

    return get_op(_dot, output_shape=_compute_output_shape)([x, y])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号