ops.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:dnc-theano 作者: khaotik 项目源码 文件源码
def op_matmul(s_x_, s_y_, axes_=(-2, -1)):
    '''
    Limited implementation of np.matmul, does not support broadcasting.

    Both operands must have the same number of dimensions. The two axes
    named by `axes_` hold the matrices; every other axis is a batch axis and
    is flattened into a single batch dimension (sized from `s_x_`) before the
    product is taken, so batch sizes of both operands have to agree.

    Args:
        s_x_: (batch of) matrix(matrices), left operand
        s_y_: (batch of) matrix(matrices), right operand
        axes_: tuple of two distinct ints, (row axis, column axis) of the
            matrices; negative values count from the last axis

    Returns:
        The (batch of) matrix product, laid out like the inputs: axis
        `axes_[0]` has the row count of `s_x_`, axis `axes_[1]` has the
        column count of `s_y_`, batch axes keep their positions.

    Raises:
        AssertionError: on mismatched ndim, ndim < 2, out-of-range axes or
            identical row/column axes.
    '''
    ndim = s_x_.ndim
    assert ndim == s_y_.ndim, 'operands must have the same ndim'
    assert ndim >= 2, 'operands must be at least 2-dimensional'
    assert -ndim <= axes_[0] < ndim
    assert -ndim <= axes_[1] < ndim
    axes = axes_[0] % ndim, axes_[1] % ndim
    # A matrix needs two different axes; identical ones used to fall through
    # silently and produce a meaningless result.
    assert axes[0] != axes[1], 'matrix axes must be distinct'

    if ndim == 2:
        if axes == (0, 1):
            return T.dot(s_x_, s_y_)
        # axes == (1, 0): both operands store their matrices transposed and
        # the result must be stored transposed too; (x.T . y.T).T == y . x
        return T.dot(s_y_, s_x_)

    # Local import keeps the block working on Python 3, where `reduce` is no
    # longer a builtin.
    from functools import reduce

    # Forward permutation: all batch axes first (original order), then the
    # row axis, then the column axis. Building it directly avoids the
    # swap-collision problems that arise when the requested axes overlap the
    # last two positions.
    batch_axes = [i for i in range(ndim) if i not in axes]
    perm = batch_axes + [axes[0], axes[1]]
    # Inverse permutation, used to put the result back in the input layout.
    inv_perm = [0] * ndim
    for pos, axis in enumerate(perm):
        inv_perm[axis] = pos

    s_shp = T.shape(s_x_)
    s_size = reduce(T.mul, [s_shp[i] for i in batch_axes])
    s_szu = s_shp[axes[0]]
    s_szv = s_shp[axes[1]]
    s_szw = T.shape(s_y_)[axes[1]]

    # Shape of the product while still in permuted (batch..., row, col) order.
    s_shp2 = [s_shp[i] for i in batch_axes] + [s_szu, s_szw]

    s_x = s_x_.transpose(*perm).reshape((s_size, s_szu, s_szv))
    s_y = s_y_.transpose(*perm).reshape((s_size, s_szv, s_szw))
    return T.batched_dot(s_x, s_y).reshape(s_shp2).transpose(*inv_perm)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号