nonlinear_expansion.py source code

python

Project: cuicuilco  Author: AlbertoEsc
import numpy


def pairwise_expansion(x, func, reflexive=True):
    """Computes func(xi, xj) for every pair of columns xi, xj of x with i <= j.

    func is an arbitrary function that broadcasts over numpy arrays.
    If reflexive is False, only pairs with i < j are considered.
    """
    x_height, x_width = x.shape
    # Diagonal offset: 0 keeps the pairs (i, i), 1 drops them.
    if reflexive:
        k = 0
    else:
        k = 1
    # Boolean upper-triangular mask selecting each unordered pair once.
    mask = numpy.triu(numpy.ones((x_width, x_width)), k) > 0.5
    # Broadcasting gives yexp[n, i, j] = func(x[n, i], x[n, j]).
    y1 = x.reshape(x_height, x_width, 1)
    y2 = x.reshape(x_height, 1, x_width)
    yexp = func(y1, y2)

    # Keep only the masked pairs: one output column per selected (i, j).
    out = yexp[:, mask]
    return out
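
The column order produced by the mask can be hard to see from the broadcasting code alone. Below is a minimal sketch of a loop-based reference that should give the same columns in the same order. The name `pairwise_expansion_loop` and the sample array are hypothetical and not part of the original file; the sketch assumes `func` accepts two 1-D arrays.

```python
import numpy


def pairwise_expansion_loop(x, func, reflexive=True):
    """Hypothetical loop-based reference: one output column per pair (i, j)."""
    x_width = x.shape[1]
    # 0 keeps the pairs (i, i); 1 starts each inner loop at j = i + 1.
    offset = 0 if reflexive else 1
    columns = []
    for i in range(x_width):
        for j in range(i + offset, x_width):
            columns.append(func(x[:, i], x[:, j]))
    # Note: numpy.stack raises ValueError if no pair was selected.
    return numpy.stack(columns, axis=1)


# Sample data (an assumption for illustration): 2 samples, 3 features.
x = numpy.array([[1.0, 2.0, 3.0],
                 [4.0, 5.0, 6.0]])

# Pair order with reflexive=True: (0,0), (0,1), (0,2), (1,1), (1,2), (2,2)
products = pairwise_expansion_loop(x, numpy.multiply)
print(products)

# Pair order with reflexive=False: (0,1), (0,2), (1,2)
cross_products = pairwise_expansion_loop(x, numpy.multiply, reflexive=False)
print(cross_products)
```

For an input with w columns, this yields w * (w + 1) / 2 output columns when reflexive is True and w * (w - 1) / 2 when it is False, matching the number of True entries in the upper-triangular mask.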