def pairwise_expansion(x, func, reflexive=True):
    """Apply ``func`` to every pair of columns of ``x``, sample by sample.

    For each row ``n`` of ``x`` and each index pair ``(i, j)`` with ``i <= j``
    (or ``i < j`` when ``reflexive`` is false), the output holds
    ``func(x[n, i], x[n, j])``.

    Parameters
    ----------
    x : array-like, shape (height, width)
        Input data; anything ``numpy.asanyarray`` accepts (ndarray subclasses
        are passed through unchanged).
    func : callable
        Called exactly once as ``func(a, b)`` with ``a`` of shape
        ``(height, width, 1)`` and ``b`` of shape ``(height, 1, width)``.
        Its result is expected to have shape ``(height, width, width)``
        after broadcasting, so it must operate elementwise on arrays.
    reflexive : bool, optional
        If true (default) the pairs with ``i == j`` are included; otherwise
        only pairs with ``i != j`` (i.e. ``i < j``) are produced.

    Returns
    -------
    out : array, shape (height, n_pairs)
        ``n_pairs`` is ``width * (width + 1) / 2`` when reflexive and
        ``width * (width - 1) / 2`` otherwise.  Pairs are ordered row-major
        over the upper triangle: (0, 0), (0, 1), ..., (1, 1), (1, 2), ...

    Raises
    ------
    ValueError
        If ``x`` is not two-dimensional.
    """
    x = numpy.asanyarray(x)
    if x.ndim != 2:
        raise ValueError(
            "pairwise_expansion expects a 2-D array, got %d dimension(s)" % x.ndim
        )
    x_height, x_width = x.shape

    # Diagonal offset of the upper triangle: 0 keeps the i == j pairs,
    # 1 starts one above the diagonal and therefore drops them.
    k = 0 if reflexive else 1
    mask = numpy.triu(numpy.ones((x_width, x_width), dtype=bool), k)

    # Broadcasting a column view against a row view makes func see every
    # (i, j) combination in a single vectorized call.
    y1 = x.reshape(x_height, x_width, 1)
    y2 = x.reshape(x_height, 1, x_width)
    yexp = func(y1, y2)

    # The boolean mask flattens the two trailing axes, keeping only the
    # selected upper-triangular entries for every sample.
    return yexp[:, mask]
评论列表
文章目录